Reconstructed from a copy in which every newline had been collapsed and all
template argument lists (<...>) had been stripped. Hunk bodies are re-broken to
match the @@ line counts; indentation and blank-line positions inside context
lines are best-effort, so check against the repository before applying.
The two .cpp hunks that gain doc comments carry updated "+" counts; the post-image
ids on their "index" lines are those of the original patch, not of this content.

diff --git a/Makefile b/Makefile
index 45dfde5..b6af8fa 100644
--- a/Makefile
+++ b/Makefile
@@ -67,6 +67,13 @@ release:
 	cmake $(GENERATOR) $(BUILD_FLAGS) $(CLIENT_FLAGS) -DCMAKE_BUILD_TYPE=Release -S ./duckdb/ -B build/release && \
 	cmake --build build/release --config Release
 
+
+reldebug:
+	mkdir -p build/reldebug && \
+	cmake $(GENERATOR) $(BUILD_FLAGS) $(CLIENT_FLAGS) -DCMAKE_BUILD_TYPE=RelWithDebInfo -S ./duckdb/ -B build/reldebug && \
+	cmake --build build/reldebug --config RelWithDebInfo
+
+
 ##### Client build
 JS_BUILD_FLAGS=-DBUILD_NODE=1 -DDUCKDB_EXTENSION_${EXTENSION_NAME}_SHOULD_LINK=0
 PY_BUILD_FLAGS=-DBUILD_PYTHON=1 -DDUCKDB_EXTENSION_${EXTENSION_NAME}_SHOULD_LINK=0
diff --git a/src/hnsw/hnsw_index.cpp b/src/hnsw/hnsw_index.cpp
index 05271b2..944b64a 100644
--- a/src/hnsw/hnsw_index.cpp
+++ b/src/hnsw/hnsw_index.cpp
@@ -444,8 +444,18 @@ ErrorData HNSWIndex::Insert(IndexLock &lock, DataChunk &input, Vector &rowid_vec
 	return ErrorData {};
 }
 
-ErrorData HNSWIndex::Append(IndexLock &lock, DataChunk &entries, Vector &rowid_vec) {
-	Construct(entries, rowid_vec, unum::usearch::index_dense_t::any_thread());
+// Evaluate the index expressions over the appended chunk, then add the result
+// to the index under the given row identifiers.
+ErrorData HNSWIndex::Append(IndexLock &lock, DataChunk &appended_data, Vector &row_identifiers) {
+	DataChunk expression_result;
+	expression_result.Initialize(Allocator::DefaultAllocator(), logical_types);
+
+	// first resolve the expressions for the index
+	ExecuteExpressions(appended_data, expression_result);
+
+	// now insert into the index
+	Construct(expression_result, row_identifiers, unum::usearch::index_dense_t::any_thread());
+
 	return ErrorData {};
 }
 
diff --git a/src/hnsw/hnsw_plan_index_scan.cpp b/src/hnsw/hnsw_plan_index_scan.cpp
index fe4f9d4..6bef6be 100644
--- a/src/hnsw/hnsw_plan_index_scan.cpp
+++ b/src/hnsw/hnsw_plan_index_scan.cpp
@@ -1,15 +1,17 @@
-#include "duckdb/optimizer/optimizer_extension.hpp"
 #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/optimizer/column_lifetime_analyzer.hpp"
+#include "duckdb/optimizer/optimizer_extension.hpp"
 #include "duckdb/planner/expression/bound_constant_expression.hpp"
-#include "duckdb/planner/operator/logical_top_n.hpp"
+#include "duckdb/planner/expression/bound_function_expression.hpp"
 #include "duckdb/planner/operator/logical_get.hpp"
 #include "duckdb/planner/operator/logical_projection.hpp"
+#include "duckdb/planner/operator/logical_top_n.hpp"
 #include "duckdb/storage/data_table.hpp"
-
 #include "hnsw/hnsw.hpp"
 #include "hnsw/hnsw_index.hpp"
 #include "hnsw/hnsw_index_scan.hpp"
+#include "duckdb/optimizer/remove_unused_columns.hpp"
+#include "duckdb/planner/expression_iterator.hpp"
 
 namespace duckdb {
 
@@ -22,12 +24,12 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		optimize_function = HNSWIndexScanOptimizer::Optimize;
 	}
 
-	static void TryOptimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
+	static bool TryOptimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
 		// Look for a TopN operator
 		auto &op = *plan;
 
 		if (op.type != LogicalOperatorType::LOGICAL_TOP_N) {
-			return;
+			return false;
 		}
 
 		// Look for a expression that is a distance expression
@@ -35,19 +37,19 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 
 		if (top_n.orders.size() != 1) {
 			// We can only optimize if there is a single order by expression right now
-			return;
+			return false;
 		}
 
 		auto &order = top_n.orders[0];
 
 		if (order.type != OrderType::ASCENDING) {
 			// We can only optimize if the order by expression is ascending
-			return;
+			return false;
 		}
 
 		if (order.expression->type != ExpressionType::BOUND_COLUMN_REF) {
 			// The expression has to reference the child operator (a projection with the distance function)
-			return;
+			return false;
 		}
 
 		auto &bound_column_ref = order.expression->Cast<BoundColumnRefExpression>();
@@ -55,19 +57,19 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		auto &immediate_child = top_n.children[0];
 		if (immediate_child->type != LogicalOperatorType::LOGICAL_PROJECTION) {
 			// The child has to be a projection
-			return;
+			return false;
 		}
 
 		auto &projection = immediate_child->Cast<LogicalProjection>();
 		auto projection_index = bound_column_ref.binding.column_index;
 		if (projection.expressions[projection_index]->type != ExpressionType::BOUND_FUNCTION) {
 			// The expression has to be a function
-			return;
+			return false;
 		}
 		auto &bound_function = projection.expressions[projection_index]->Cast<BoundFunctionExpression>();
 		if (!HNSWIndex::IsDistanceFunction(bound_function.function.name)) {
 			// We can only optimize if the order by expression is a distance function
-			return;
+			return false;
 		}
 
 		// Figure out the query vector
@@ -78,7 +80,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 			target_value = bound_function.children[1]->Cast<BoundConstantExpression>().value;
 		} else {
 			// We can only optimize if one of the children is a constant
-			return;
+			return false;
 		}
 
 		// TODO: We should check that the other argument to the distance function is a column reference
@@ -87,7 +89,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		auto value_type = target_value.type();
 		if (value_type.id() != LogicalTypeId::ARRAY) {
 			// We can only optimize if the constant is an array
-			return;
+			return false;
 		}
 		auto array_size = ArrayType::GetSize(value_type);
 		auto array_inner_type = ArrayType::GetChildType(value_type);
@@ -96,7 +98,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 			bool ok = target_value.DefaultTryCastAs(LogicalType::ARRAY(LogicalType::FLOAT, array_size), true);
 			if (!ok) {
 				// We can only optimize if the array is of floats or we can cast it to floats
-				return;
+				return false;
 			}
 		}
 
@@ -106,7 +108,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 			// TODO: Handle joins?
 			if (child->children.size() != 1) {
 				// Either 0 or more than 1 child
-				return;
+				return false;
 			}
 			child = child->children[0].get();
 		}
@@ -114,7 +116,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		auto &get = child->Cast<LogicalGet>();
 		// Check if the get is a table scan
 		if (get.function.name != "seq_scan") {
-			return;
+			return false;
 		}
 
 		// We have a top-n operator on top of a table scan
@@ -124,7 +126,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		auto &table = *get.GetTable();
 		if (!table.IsDuckTable()) {
 			// We can only replace the scan if the table is a duck table
-			return;
+			return false;
 		}
 
 		auto &duck_table = table.Cast<DuckTableEntry>();
@@ -166,7 +168,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 
 		if (!bind_data) {
 			// No index found
-			return;
+			return false;
 		}
 
 		// Replace the scan with our custom index scan function
@@ -177,17 +179,76 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
 		get.estimated_cardinality = cardinality->estimated_cardinality;
 		get.bind_data = std::move(bind_data);
 
 		// Remove the TopN operator
 		plan = std::move(top_n.children[0]);
+		return true;
 	}
 
-	static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
-
-		TryOptimize(context, info, plan);
+	// Run TryOptimize on this operator and then on each child of whatever
+	// operator "plan" holds afterwards. Returns true if any call rewrote a
+	// TopN into an HNSW index scan.
+	static bool OptimizeChildren(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
+		bool ok = TryOptimize(context, info, plan);
 
 		// Recursively optimize the children
 		for (auto &child : plan->children) {
-			Optimize(context, info, child);
+			ok |= OptimizeChildren(context, info, child);
+		}
+		return ok;
+	}
+
+	// Neutralise projection columns that nothing reads after the HNSW rewrite.
+	// For a projection stacked directly on a projection that sits on an
+	// "hnsw_index_scan" get, every child expression whose output binding the
+	// parent does not reference is replaced with a constant placeholder.
+	static void MergeProjections(unique_ptr<LogicalOperator> &plan) {
+		if (plan->type == LogicalOperatorType::LOGICAL_PROJECTION) {
+			if (plan->children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) {
+				auto &child = plan->children[0];
+
+				if (child->children[0]->type == LogicalOperatorType::LOGICAL_GET &&
+				    child->children[0]->Cast<LogicalGet>().function.name == "hnsw_index_scan") {
+					auto &parent_projection = plan->Cast<LogicalProjection>();
+					auto &child_projection = child->Cast<LogicalProjection>();
+
+					// Collect every binding the parent projection reads
+					column_binding_set_t referenced_bindings;
+					for (auto &expr : parent_projection.expressions) {
+						ExpressionIterator::EnumerateExpression(expr, [&](Expression &expr_ref) {
+							if (expr_ref.type == ExpressionType::BOUND_COLUMN_REF) {
+								auto &bound_column_ref = expr_ref.Cast<BoundColumnRefExpression>();
+								referenced_bindings.insert(bound_column_ref.binding);
+							}
+						});
+					}
+
+					auto child_bindings = child_projection.GetColumnBindings();
+					for (idx_t i = 0; i < child_projection.expressions.size(); i++) {
+						auto &expr = child_projection.expressions[i];
+						auto &outgoing_binding = child_bindings[i];
+
+						if (referenced_bindings.find(outgoing_binding) == referenced_bindings.end()) {
+							// The binding is not referenced, so the expression is dead.
+							// Positionality matters, so keep the slot and swap in a
+							// TINYINT constant instead of erasing the expression.
+							expr = make_uniq_base<Expression, BoundConstantExpression>(Value(LogicalType::TINYINT));
+						}
+					}
+					return;
+				}
+			}
+		}
+		for (auto &child : plan->children) {
+			MergeProjections(child);
+		}
+	}
+
+	// Optimizer entry point: rewrite eligible TopN operators and, only when a
+	// rewrite happened, clean up the projections above the new index scan.
+	static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
+		bool did_use_hnsw_scan = OptimizeChildren(context, info, plan);
+		if (did_use_hnsw_scan) {
+			MergeProjections(plan);
 		}
 	}
 };