Skip to content

Commit

Permalink
update duckdb, add macro to lateral join vector tables
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Aug 12, 2024
1 parent 9ff608f commit 067cfb1
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 8 deletions.
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 2266 files
1 change: 1 addition & 0 deletions src/hnsw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_logical_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_macros.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_physical_create.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace duckdb {

class LinkedBlock {
public:
static constexpr const idx_t BLOCK_SIZE = Storage::BLOCK_SIZE - sizeof(validity_t);
static constexpr const idx_t BLOCK_SIZE = Storage::DEFAULT_BLOCK_SIZE - sizeof(validity_t);
static constexpr const idx_t BLOCK_DATA_SIZE = BLOCK_SIZE - sizeof(IndexPointer);
static_assert(BLOCK_SIZE > sizeof(IndexPointer), "Block size must be larger than the size of an IndexPointer");

Expand Down
119 changes: 119 additions & 0 deletions src/hnsw/hnsw_index_macros.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "duckdb/function/table_macro_function.hpp"
#include "duckdb/main/extension_util.hpp"
#include "duckdb/optimizer/matcher/expression_matcher.hpp"
#include "hnsw/hnsw.hpp"
#include "hnsw/hnsw_index.hpp"
#include "duckdb/parser/parser.hpp"

namespace duckdb {

// SQL body of the `vss_join` table macro (registered in HNSWModule::RegisterMacros).
//
// Macro parameters referenced by the query:
//   left_table, right_table - table names, resolved through query_table(...::VARCHAR)
//   left_col, right_col     - the vector expressions compared against each other
//   k                       - count argument forwarded to min_by / max_by
//   metric                  - named parameter selecting the scoring function
//
// For every row of `left_table`, each row of `right_table` is given a `score`:
//   'l2sq' / 'l2'    -> array_distance
//   'cosine' / 'cos' -> array_cosine_similarity
//   'ip'             -> array_inner_product
//   anything else    -> error('Unknown metric')
// The distance metrics keep rows via min_by(..., k); all other metrics use max_by(..., k).
//
// Output columns: `score`, `left_tbl` (the left row) and `right_tbl` (the matched row
// re-packed as a struct with the `score` column filtered out).
//
// NOTE(review): the outer query addresses `matches.*` / `matches.score` although the inner
// projection is aliased `result`; this presumably relies on unnest(max_depth := 2)
// expanding the struct into separate columns - confirm against DuckDB's unnest semantics.
// NOTE: this is a raw string literal, so its text is runtime data and must not be reformatted.
static constexpr auto VSS_JOIN_MACRO = R"(
SELECT
score,
left_tbl,
right_tbl,
FROM
(SELECT * FROM query_table(left_table::VARCHAR)) as left_tbl,
(
SELECT
struct_pack(*columns([x for x in (matches.*) if x != 'score'])) as right_tbl,
matches.score as score
FROM (
SELECT (
unnest(
CASE WHEN metric = 'l2sq' OR metric = 'l2'
THEN min_by(tbl, tbl.score, k)
ELSE max_by(tbl, tbl.score, k)
END,
max_depth := 2
)
) as result
FROM (
SELECT
*,
CASE
WHEN metric = 'l2sq' OR metric = 'l2'
THEN array_distance(left_col, right_col)
WHEN metric = 'cosine' OR metric = 'cos'
THEN array_cosine_similarity(left_col, right_col)
WHEN metric = 'ip'
THEN array_inner_product(left_col, right_col)
ELSE error('Unknown metric')
END as score,
FROM query_table(right_table::VARCHAR)
) as tbl
) as matches
)
)";

// SQL body of the `vss_match` table macro (registered in HNSWModule::RegisterMacros).
//
// Macro parameters referenced by the query:
//   right_table         - table name, resolved through query_table(...::VARCHAR)
//   left_col, right_col - the vector expressions compared against each other
//   k                   - count argument forwarded to min_by / max_by
//   metric              - named parameter selecting the scoring function
//
// Each row of `right_table` is scored with the same metric dispatch as `vss_join`
// (array_distance / array_cosine_similarity / array_inner_product, otherwise
// error('Unknown metric')). The scored rows are then aggregated into a single output
// column `matches`, built from {'score': ..., 'row': ...} structs using min_by(..., k)
// for 'l2sq' / 'l2' and max_by(..., k) for every other metric.
//
// `left_col` is not produced inside this query; it presumably has to come from the
// calling context (e.g. a preceding table in the FROM clause) - TODO confirm.
// NOTE: this is a raw string literal, so its text is runtime data and must not be reformatted.
static constexpr auto VSS_MATCH_MACRO = R"(
SELECT
right_tbl as matches,
FROM
(
SELECT (
CASE WHEN metric = 'l2sq' OR metric = 'l2'
THEN min_by({'score': score, 'row': t}, score, k)
ELSE max_by({'score': score, 'row': t}, score, k)
END
) as right_tbl
FROM (
SELECT
CASE
WHEN metric = 'l2sq' OR metric = 'l2'
THEN array_distance(left_col, right_col)
WHEN metric = 'cosine' OR metric = 'cos'
THEN array_cosine_similarity(left_col, right_col)
WHEN metric = 'ip'
THEN array_inner_product(left_col, right_col)
ELSE error('Unknown metric')
END as score,
tbl as t,
FROM (SELECT * FROM query_table(right_table::VARCHAR)) as tbl
)
)
)";

//-------------------------------------------------------------------------
// Register
//-------------------------------------------------------------------------
// Builds a table macro from a SQL string and registers it with the database.
//
//   db           - database instance the macro is registered in
//   name         - name under which the macro is registered
//   query        - SQL text; the last parsed statement is used as the macro body
//   params       - positional parameter names, in call order
//   named_params - named parameters together with their default values
static void RegisterTableMacro(DatabaseInstance &db, const string &name, const string &query,
                               const vector<string> &params, const child_list_t<Value> &named_params) {

	// Parse the SQL text and take ownership of the query node of the last statement
	Parser sql_parser;
	sql_parser.ParseQuery(query);
	auto &select_stmt = sql_parser.statements.back()->Cast<SelectStatement>();
	auto macro = make_uniq<TableMacroFunction>(std::move(select_stmt.node));

	// Positional parameters are represented as column references
	for (const auto &positional_name : params) {
		macro->parameters.push_back(make_uniq<ColumnRefExpression>(positional_name));
	}

	// Named parameters carry their default as a constant expression
	for (const auto &named_entry : named_params) {
		macro->default_parameters[named_entry.first] = make_uniq<ConstantExpression>(named_entry.second);
	}

	// Describe the catalog entry: a temporary, internal macro in the default schema
	CreateMacroInfo macro_info(CatalogType::TABLE_MACRO_ENTRY);
	macro_info.name = name;
	macro_info.schema = DEFAULT_SCHEMA;
	macro_info.internal = true;
	macro_info.temporary = true;
	macro_info.macros.push_back(std::move(macro));

	ExtensionUtil::RegisterFunction(db, macro_info);
}

// Registers the `vss_join` and `vss_match` table macros in the given database.
// Both macros expose a named `metric` parameter whose default is 'l2sq'.
void HNSWModule::RegisterMacros(DatabaseInstance &db) {
	// Named parameters shared by both macros
	const child_list_t<Value> named_defaults {{"metric", Value("l2sq")}};

	// vss_join(left_table, right_table, left_col, right_col, k, metric := 'l2sq')
	const vector<string> join_params {"left_table", "right_table", "left_col", "right_col", "k"};
	RegisterTableMacro(db, "vss_join", VSS_JOIN_MACRO, join_params, named_defaults);

	// vss_match(right_table, left_col, right_col, k, metric := 'l2sq')
	const vector<string> match_params {"right_table", "left_col", "right_col", "k"};
	RegisterTableMacro(db, "vss_match", VSS_MATCH_MACRO, match_params, named_defaults);
}

} // namespace duckdb
13 changes: 7 additions & 6 deletions src/hnsw/hnsw_index_physical_create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ SinkCombineResultType PhysicalCreateHNSWIndex::Combine(ExecutionContext &context
class HNSWIndexConstructTask final : public ExecutorTask {
public:
HNSWIndexConstructTask(shared_ptr<Event> event_p, ClientContext &context, CreateHNSWIndexGlobalState &gstate_p,
size_t thread_id_p)
: ExecutorTask(context, std::move(event_p)), gstate(gstate_p), thread_id(thread_id_p), local_scan_state() {
size_t thread_id_p, const PhysicalCreateHNSWIndex &op_p)
: ExecutorTask(context, std::move(event_p), op_p), gstate(gstate_p), thread_id(thread_id_p), local_scan_state() {
// Initialize the scan chunk
gstate.collection->InitializeScanChunk(scan_chunk);
}
Expand Down Expand Up @@ -209,11 +209,12 @@ class HNSWIndexConstructTask final : public ExecutorTask {

class HNSWIndexConstructionEvent final : public BasePipelineEvent {
public:
HNSWIndexConstructionEvent(CreateHNSWIndexGlobalState &gstate_p, Pipeline &pipeline_p, CreateIndexInfo &info_p,
HNSWIndexConstructionEvent(const PhysicalCreateHNSWIndex &op_p, CreateHNSWIndexGlobalState &gstate_p, Pipeline &pipeline_p, CreateIndexInfo &info_p,
const vector<column_t> &storage_ids_p, DuckTableEntry &table_p)
: BasePipelineEvent(pipeline_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p), table(table_p) {
: BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p), table(table_p) {
}

const PhysicalCreateHNSWIndex &op;
CreateHNSWIndexGlobalState &gstate;
CreateIndexInfo &info;
const vector<column_t> &storage_ids;
Expand All @@ -229,7 +230,7 @@ class HNSWIndexConstructionEvent final : public BasePipelineEvent {

vector<shared_ptr<Task>> construct_tasks;
for (size_t tnum = 0; tnum < num_threads; tnum++) {
construct_tasks.push_back(make_uniq<HNSWIndexConstructTask>(shared_from_this(), context, gstate, tnum));
construct_tasks.push_back(make_uniq<HNSWIndexConstructTask>(shared_from_this(), context, gstate, tnum, op));
}
SetTasks(std::move(construct_tasks));
}
Expand Down Expand Up @@ -295,7 +296,7 @@ SinkFinalizeType PhysicalCreateHNSWIndex::Finalize(Pipeline &pipeline, Event &ev
collection->InitializeScan(gstate.scan_state, ColumnDataScanProperties::ALLOW_ZERO_COPY);

// Create a new event that will construct the index
auto new_event = make_shared_ptr<HNSWIndexConstructionEvent>(gstate, pipeline, *info, storage_ids, table);
auto new_event = make_shared_ptr<HNSWIndexConstructionEvent>(*this, gstate, pipeline, *info, storage_ids, table);
event.InsertEvent(std::move(new_event));

return SinkFinalizeType::READY;
Expand Down
2 changes: 2 additions & 0 deletions src/include/hnsw/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct HNSWModule {
RegisterIndexPragmas(db);
RegisterPlanIndexScan(db);
RegisterPlanIndexCreate(db);
RegisterMacros(db);
}

private:
Expand All @@ -20,6 +21,7 @@ struct HNSWModule {
static void RegisterIndexPragmas(DatabaseInstance &db);
static void RegisterPlanIndexScan(DatabaseInstance &db);
static void RegisterPlanIndexCreate(DatabaseInstance &db);
static void RegisterMacros(DatabaseInstance &db);
};

} // namespace duckdb
19 changes: 19 additions & 0 deletions test/sql/hnsw/hnsw_join_macro.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# name: test/sql/hnsw/hnsw_join_macro.test
# description: Test the vss_match table macro when used in a lateral join
# group: [hnsw]

require vss

# Vector table to search: 9 * 9 * 9 = 729 rows covering every (a, b, c) in 1..9
statement ok
CREATE TABLE t1 (id int, vec FLOAT[3]);

statement ok
INSERT INTO t1 SELECT row_number() over (), array_value(a,b,c) FROM range(1,10) ra(a), range(1,10) rb(b), range(1,10) rc(c);

# Query vectors: one row per lookup
statement ok
CREATE TABLE s(s_vec FLOAT[3]);

statement ok
INSERT INTO s VALUES ([5,5,5]), ([1,1,1]);

# Each query vector must get exactly k = 3 matches.
# The SELECT projects a single column, so the record type string is 'I' (was 'III').
query I
SELECT len(matches) = 3 FROM s, vss_match(t1, s_vec, vec, 3) as res;
----
true
true

0 comments on commit 067cfb1

Please sign in to comment.