Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix index append, optimize removing unused distance projection if not referenced #11

Merged
merged 2 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ release:
cmake $(GENERATOR) $(BUILD_FLAGS) $(CLIENT_FLAGS) -DCMAKE_BUILD_TYPE=Release -S ./duckdb/ -B build/release && \
cmake --build build/release --config Release


# reldebug: configure the DuckDB sources in ./duckdb/ with CMAKE_BUILD_TYPE=RelWithDebInfo
# into build/reldebug (created if missing), then build that tree with the RelWithDebInfo config.
# Uses the same $(GENERATOR), $(BUILD_FLAGS) and $(CLIENT_FLAGS) variables as the release target.
reldebug:
mkdir -p build/reldebug && \
cmake $(GENERATOR) $(BUILD_FLAGS) $(CLIENT_FLAGS) -DCMAKE_BUILD_TYPE=RelWithDebInfo -S ./duckdb/ -B build/reldebug && \
cmake --build build/reldebug --config RelWithDebInfo


##### Client build
JS_BUILD_FLAGS=-DBUILD_NODE=1 -DDUCKDB_EXTENSION_${EXTENSION_NAME}_SHOULD_LINK=0
PY_BUILD_FLAGS=-DBUILD_PYTHON=1 -DDUCKDB_EXTENSION_${EXTENSION_NAME}_SHOULD_LINK=0
Expand Down
12 changes: 10 additions & 2 deletions src/hnsw/hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,16 @@ ErrorData HNSWIndex::Insert(IndexLock &lock, DataChunk &input, Vector &rowid_vec
return ErrorData {};
}

ErrorData HNSWIndex::Append(IndexLock &lock, DataChunk &entries, Vector &rowid_vec) {
Construct(entries, rowid_vec, unum::usearch::index_dense_t::any_thread());
//! Appends a chunk of newly added rows to the index.
//!
//! The incoming chunk holds the raw appended data, not the values the index is built on, so
//! the index expressions are evaluated first and only their result is handed to Construct.
//!
//! \param lock             Index lock supplied by the caller (not referenced in this body).
//! \param appended_data    The rows that were appended.
//! \param row_identifiers  Row ids belonging to the rows in appended_data.
//! \return                 Always a default-constructed ErrorData.
ErrorData HNSWIndex::Append(IndexLock &lock, DataChunk &appended_data, Vector &row_identifiers) {
	// Scratch chunk shaped after the index's logical types; receives the evaluated expressions
	DataChunk index_input;
	index_input.Initialize(Allocator::DefaultAllocator(), logical_types);

	// Step 1: evaluate the index expressions over the appended rows
	ExecuteExpressions(appended_data, index_input);

	// Step 2: insert the evaluated values, paired with their row ids, into the index
	Construct(index_input, row_identifiers, unum::usearch::index_dense_t::any_thread());

	return ErrorData();
}

Expand Down
100 changes: 77 additions & 23 deletions src/hnsw/hnsw_plan_index_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#include "duckdb/optimizer/optimizer_extension.hpp"
#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/optimizer/column_lifetime_analyzer.hpp"
#include "duckdb/optimizer/optimizer_extension.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/operator/logical_top_n.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/operator/logical_get.hpp"
#include "duckdb/planner/operator/logical_projection.hpp"
#include "duckdb/planner/operator/logical_top_n.hpp"
#include "duckdb/storage/data_table.hpp"

#include "hnsw/hnsw.hpp"
#include "hnsw/hnsw_index.hpp"
#include "hnsw/hnsw_index_scan.hpp"
#include "duckdb/optimizer/remove_unused_columns.hpp"
#include "duckdb/planner/expression_iterator.hpp"

namespace duckdb {

Expand All @@ -22,52 +24,52 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
optimize_function = HNSWIndexScanOptimizer::Optimize;
}

static void TryOptimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
static bool TryOptimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
// Look for a TopN operator
auto &op = *plan;

if (op.type != LogicalOperatorType::LOGICAL_TOP_N) {
return;
return false;
}

// Look for a expression that is a distance expression
auto &top_n = op.Cast<LogicalTopN>();

if (top_n.orders.size() != 1) {
// We can only optimize if there is a single order by expression right now
return;
return false;
}

auto &order = top_n.orders[0];

if (order.type != OrderType::ASCENDING) {
// We can only optimize if the order by expression is ascending
return;
return false;
}

if (order.expression->type != ExpressionType::BOUND_COLUMN_REF) {
// The expression has to reference the child operator (a projection with the distance function)
return;
return false;
}
auto &bound_column_ref = order.expression->Cast<BoundColumnRefExpression>();

// find the expression that is referenced
auto &immediate_child = top_n.children[0];
if (immediate_child->type != LogicalOperatorType::LOGICAL_PROJECTION) {
// The child has to be a projection
return;
return false;
}
auto &projection = immediate_child->Cast<LogicalProjection>();
auto projection_index = bound_column_ref.binding.column_index;

if (projection.expressions[projection_index]->type != ExpressionType::BOUND_FUNCTION) {
// The expression has to be a function
return;
return false;
}
auto &bound_function = projection.expressions[projection_index]->Cast<BoundFunctionExpression>();
if (!HNSWIndex::IsDistanceFunction(bound_function.function.name)) {
// We can only optimize if the order by expression is a distance function
return;
return false;
}

// Figure out the query vector
Expand All @@ -78,7 +80,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
target_value = bound_function.children[1]->Cast<BoundConstantExpression>().value;
} else {
// We can only optimize if one of the children is a constant
return;
return false;
}

// TODO: We should check that the other argument to the distance function is a column reference
Expand All @@ -87,7 +89,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
auto value_type = target_value.type();
if (value_type.id() != LogicalTypeId::ARRAY) {
// We can only optimize if the constant is an array
return;
return false;
}
auto array_size = ArrayType::GetSize(value_type);
auto array_inner_type = ArrayType::GetChildType(value_type);
Expand All @@ -96,7 +98,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
bool ok = target_value.DefaultTryCastAs(LogicalType::ARRAY(LogicalType::FLOAT, array_size), true);
if (!ok) {
// We can only optimize if the array is of floats or we can cast it to floats
return;
return false;
}
}

Expand All @@ -106,15 +108,15 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
// TODO: Handle joins?
if (child->children.size() != 1) {
// Either 0 or more than 1 child
return;
return false;
}
child = child->children[0].get();
}

auto &get = child->Cast<LogicalGet>();
// Check if the get is a table scan
if (get.function.name != "seq_scan") {
return;
return false;
}

// We have a top-n operator on top of a table scan
Expand All @@ -124,7 +126,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
auto &table = *get.GetTable();
if (!table.IsDuckTable()) {
// We can only replace the scan if the table is a duck table
return;
return false;
}

auto &duck_table = table.Cast<DuckTableEntry>();
Expand Down Expand Up @@ -166,7 +168,7 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {

if (!bind_data) {
// No index found
return;
return false;
}

// Replace the scan with our custom index scan function
Expand All @@ -177,17 +179,69 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
get.estimated_cardinality = cardinality->estimated_cardinality;
get.bind_data = std::move(bind_data);


// Remove the distance function from the projection
// projection.expressions.erase(projection.expressions.begin() + static_cast<ptrdiff_t>(projection_index));
//top_n.expressions

// Remove the TopN operator
plan = std::move(top_n.children[0]);
return true;
}

static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {

TryOptimize(context, info, plan);
//! Runs TryOptimize on `plan` and then on every operator below it.
//!
//! \param context  Client context forwarded to TryOptimize.
//! \param info     Optimizer extension info forwarded to TryOptimize.
//! \param plan     Root of the (sub)plan to process; TryOptimize may replace it in place.
//! \return         true if TryOptimize returned true for at least one operator in the subtree.
static bool OptimizeChildren(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
	// TryOptimize may swap `plan` for a different operator on success, so the children are
	// only looked at after it has run.
	auto ok = TryOptimize(context, info, plan);

	// Visit each child exactly once. Going through OptimizeChildren (instead of additionally
	// calling Optimize on the child) keeps the traversal linear in the size of the plan and
	// makes sure every success is reflected in the returned flag.
	// `|=` rather than `||`: the recursive call must run even when `ok` is already true.
	for (auto &child : plan->children) {
		ok |= OptimizeChildren(context, info, child);
	}
	return ok;
}

//! Searches the plan top-down for the operator shape
//!     PROJECTION -> PROJECTION -> GET (function name "hnsw_index_scan")
//! and, on the first match along a path, blanks out every expression of the inner projection
//! whose output binding is not referenced by the outer projection.
//!
//! Despite the name, no operators are merged: unreferenced expressions are only replaced by a
//! TINYINT-typed constant so the remaining columns keep their positions. Once a match has been
//! handled the search stops for that subtree; otherwise it continues into all children.
//!
//! \param plan  Root of the (sub)plan to inspect; the matched inner projection is modified in place.
static void MergeProjections(unique_ptr<LogicalOperator> &plan) {
	// && short-circuits, so children are only touched once the operator above them has been
	// confirmed to be a projection.
	const bool matches_pattern =
	    plan->type == LogicalOperatorType::LOGICAL_PROJECTION &&
	    plan->children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION &&
	    plan->children[0]->children[0]->type == LogicalOperatorType::LOGICAL_GET &&
	    plan->children[0]->children[0]->Cast<LogicalGet>().function.name == "hnsw_index_scan";

	if (!matches_pattern) {
		// Nothing to do at this level: keep looking further down the plan
		for (auto &sub_plan : plan->children) {
			MergeProjections(sub_plan);
		}
		return;
	}

	auto &outer = plan->Cast<LogicalProjection>();
	auto &inner = plan->children[0]->Cast<LogicalProjection>();

	// Gather every column binding that appears anywhere inside the outer projection's expressions
	column_binding_set_t used_bindings;
	for (auto &outer_expr : outer.expressions) {
		ExpressionIterator::EnumerateExpression(outer_expr, [&](Expression &node) {
			if (node.type != ExpressionType::BOUND_COLUMN_REF) {
				return;
			}
			used_bindings.insert(node.Cast<BoundColumnRefExpression>().binding);
		});
	}

	// Entry i of the bindings is treated as the output binding of inner expression i
	auto inner_bindings = inner.GetColumnBindings();
	for (idx_t col = 0; col < inner.expressions.size(); col++) {
		if (used_bindings.find(inner_bindings[col]) != used_bindings.end()) {
			// Still needed by the outer projection
			continue;
		}
		// Unreferenced: the slot cannot be erased because column positions matter, so it is
		// overwritten with a placeholder constant built from Value(LogicalType::TINYINT)
		// (presumably a NULL of that type - confirm against Value's constructor).
		inner.expressions[col] = make_uniq_base<Expression, BoundConstantExpression>(Value(LogicalType::TINYINT));
	}
}

//! Entry point of the optimizer extension (assigned to `optimize_function` in the constructor).
//!
//! First rewrites the plan via OptimizeChildren; the projection clean-up in MergeProjections
//! is only run when that pass reports that at least one rewrite took place.
//!
//! \param context  Client context forwarded to OptimizeChildren.
//! \param info     Optimizer extension info forwarded to OptimizeChildren.
//! \param plan     Root of the logical plan; may be modified in place.
static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
	if (!OptimizeChildren(context, info, plan)) {
		// No HNSW index scan was introduced, so there is nothing to clean up
		return;
	}
	MergeProjections(plan);
}
};
Expand Down
Loading