Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
samansmink committed Nov 4, 2024
1 parent b7333c0 commit b105f2c
Show file tree
Hide file tree
Showing 7 changed files with 419 additions and 31 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(EXTENSION_SOURCES
src/delta_extension.cpp
src/delta_functions.cpp
src/delta_utils.cpp
src/functions/expression_functions.cpp
src/functions/delta_scan.cpp)

### Custom config
Expand Down Expand Up @@ -99,7 +100,7 @@ ExternalProject_Add(
GIT_REPOSITORY "https://github.com/delta-incubator/delta-kernel-rs"
# WARNING: the FFI headers are currently pinned due to the C linkage issue of the c++ headers. Currently, when bumping
# the kernel version, the produced header in ./src/include/delta_kernel_ffi.hpp should be also bumped, applying the fix
GIT_TAG v0.4.0
GIT_TAG main
# Prints the env variables passed to the cargo build to the terminal, useful in debugging because passing them
# through CMake is an error-prone mess
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${RUST_UNSET_ENV_VARS} ${RUST_ENV_VARS} env
Expand Down
6 changes: 5 additions & 1 deletion src/delta_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
namespace duckdb {

static void LoadInternal(DatabaseInstance &instance) {
// Load functions
// Load Table functions
for (const auto &function : DeltaFunctions::GetTableFunctions(instance)) {
ExtensionUtil::RegisterFunction(instance, function);
}
// Load Scalar functions
for (const auto &function : DeltaFunctions::GetScalarFunctions(instance)) {
ExtensionUtil::RegisterFunction(instance, function);
}
}

void DeltaExtension::Load(DuckDB &db) {
Expand Down
9 changes: 9 additions & 0 deletions src/delta_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@ vector<TableFunctionSet> DeltaFunctions::GetTableFunctions(DatabaseInstance &ins
return functions;
}

vector<ScalarFunctionSet> DeltaFunctions::GetScalarFunctions(DatabaseInstance &instance) {
vector<ScalarFunctionSet> functions;

functions.push_back(GetExpressionFunction(instance));

return functions;
}


};
127 changes: 127 additions & 0 deletions src/delta_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,133 @@

namespace duckdb {

unique_ptr<vector<unique_ptr<BaseExpression>>> ExpressionVisitor::VisitKernelExpression(const ffi::Handle<ffi::SharedExpression> *expression) {
ExpressionVisitor state;
ffi::EngineExpressionVisitor visitor;

visitor.data = &state;
visitor.make_field_list = (uintptr_t (*)(void*, uintptr_t)) &MakeFieldList;

// Templated primitive functions
visitor.visit_literal_bool = VisitPrimitiveLiteral<bool, Value::BOOLEAN>();
visitor.visit_literal_byte = VisitPrimitiveLiteral<int8_t, Value::TINYINT>();
visitor.visit_literal_short = VisitPrimitiveLiteral<int16_t, Value::SMALLINT>();
visitor.visit_literal_int = VisitPrimitiveLiteral<int32_t, Value::INTEGER>();
visitor.visit_literal_long = VisitPrimitiveLiteral<int64_t, Value::BIGINT>();
visitor.visit_literal_float = VisitPrimitiveLiteral<float, Value::FLOAT>();
visitor.visit_literal_double = VisitPrimitiveLiteral<double, Value::DOUBLE>();

// Custom Implementations
visitor.visit_literal_timestamp = &VisitTimestampLiteral;
visitor.visit_literal_timestamp_ntz = &VisitTimestampNtzLiteral;
visitor.visit_literal_date = &VisitDateLiteral;

visitor.visit_literal_string = &VisitStringLiteral;

visitor.visit_literal_binary = &VisitBinaryLiteral;
visitor.visit_literal_null = &VisitNullLiteral;
visitor.visit_literal_array = &VisitArrayLiteral;

visitor.visit_and = VisitBinaryExpression<ExpressionType::CONJUNCTION_AND, ConjunctionExpression>();
visitor.visit_or = VisitBinaryExpression<ExpressionType::CONJUNCTION_OR, ConjunctionExpression>();

visitor.visit_lt = VisitBinaryExpression<ExpressionType::COMPARE_LESSTHAN, ComparisonExpression>();
visitor.visit_le = VisitBinaryExpression<ExpressionType::COMPARE_LESSTHANOREQUALTO, ComparisonExpression>();
visitor.visit_gt = VisitBinaryExpression<ExpressionType::COMPARE_GREATERTHAN, ComparisonExpression>();
visitor.visit_ge = VisitBinaryExpression<ExpressionType::COMPARE_GREATERTHANOREQUALTO, ComparisonExpression>();

visitor.visit_ne = VisitBinaryExpression<ExpressionType::COMPARE_NOTEQUAL, ComparisonExpression>();
visitor.visit_distinct = VisitBinaryExpression<ExpressionType::COMPARE_DISTINCT_FROM, ComparisonExpression>();

visitor.visit_in = VisitBinaryExpression<ExpressionType::COMPARE_IN, ComparisonExpression>();
visitor.visit_not_in = VisitBinaryExpression<ExpressionType::COMPARE_NOT_IN, ComparisonExpression>();

// TODO fix these
visitor.visit_add = VisitBinaryExpression<ExpressionType::COMPARE_NOT_IN, ComparisonExpression>();
visitor.visit_minus = VisitBinaryExpression<ExpressionType::COMPARE_NOT_IN, ComparisonExpression>();
visitor.visit_multiply = VisitBinaryExpression<ExpressionType::COMPARE_NOT_IN, ComparisonExpression>();
visitor.visit_divide = VisitBinaryExpression<ExpressionType::COMPARE_NOT_IN, ComparisonExpression>();

visitor.visit_column = &VisitColumnExpression;
visitor.visit_struct_expr = &VisitStructExpression;

visitor.visit_literal_struct = &Visit;

uintptr_t result = visit_expression(expression, &visitor);
return state.TakeFieldList(result);
}

void ExpressionVisitor::VisitTimestampLiteral(void* state, uintptr_t sibling_list_id, int64_t value) {
auto expression = make_uniq<ConstantExpression>(Value::TIMESTAMPTZ(static_cast<timestamp_t>(value)));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}

void ExpressionVisitor::VisitTimestampNtzLiteral(void* state, uintptr_t sibling_list_id, int64_t value) {
auto expression = make_uniq<ConstantExpression>(Value::TIMESTAMP(static_cast<timestamp_t>(value)));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}

void ExpressionVisitor::VisitDateLiteral(void* state, uintptr_t sibling_list_id, int32_t value) {
auto expression = make_uniq<ConstantExpression>(Value::DATE(static_cast<date_t>(value)));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}

void ExpressionVisitor::VisitStringLiteral(void* state, uintptr_t sibling_list_id, ffi::KernelStringSlice value) {
auto expression = make_uniq<ConstantExpression>(Value(string(value.ptr, value.len)));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}
void ExpressionVisitor::VisitBinaryLiteral(void* state, uintptr_t sibling_list_id, const uint8_t *buffer, uintptr_t len) {
auto expression = make_uniq<ConstantExpression>(Value::BLOB(buffer, len));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}
void ExpressionVisitor::VisitNullLiteral(void* state, uintptr_t sibling_list_id) {
auto expression = make_uniq<ConstantExpression>(Value());
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}
void ExpressionVisitor::VisitArrayLiteral(void* state, uintptr_t sibling_list_id, uintptr_t child_id) {
throw NotImplementedException("ExpressionVisitor::VisitArrayLiteral");
}
// TODO: same as string
void ExpressionVisitor::VisitColumnExpression(void *state, uintptr_t sibling_list_id, ffi::KernelStringSlice name) {
auto expression = make_uniq<ColumnRefExpression>(string(name.ptr, name.len));
static_cast<ExpressionVisitor*>(state)->AppendToList(sibling_list_id, std::move(expression));
}
void ExpressionVisitor::VisitStructExpression(void *state, uintptr_t sibling_list_id, uintptr_t child_list_id) {
throw NotImplementedException("ExpressionVisitor::VisitStructExpression");
}

uintptr_t ExpressionVisitor::MakeFieldList(ExpressionVisitor* state, uintptr_t capacity_hint) {
return state->MakeFieldListImpl(capacity_hint);
}
uintptr_t ExpressionVisitor::MakeFieldListImpl(uintptr_t capacity_hint) {
uintptr_t id = next_id++;
auto list = make_uniq<FieldList>();
if (capacity_hint > 0) {
list->reserve(capacity_hint);
}
inflight_lists.emplace(id, std::move(list));
return id;
}

void ExpressionVisitor::AppendToList(uintptr_t id, unique_ptr<BaseExpression> child) {
auto it = inflight_lists.find(id);
if (it == inflight_lists.end()) {
throw InternalException("ExpressionVisitor::AppendToList");
}

it->second->emplace_back(std::move(child));
}

unique_ptr<ExpressionVisitor::FieldList> ExpressionVisitor::TakeFieldList(uintptr_t id) {
auto it = inflight_lists.find(id);
if (it == inflight_lists.end()) {
throw InternalException("SchemaVisitor::TakeFieldList");
}
auto rval = std::move(it->second);
inflight_lists.erase(it);
return rval;
}

unique_ptr<SchemaVisitor::FieldList> SchemaVisitor::VisitSnapshotSchema(ffi::SharedSnapshot* snapshot) {
SchemaVisitor state;
ffi::EngineSchemaVisitor visitor;
Expand Down
5 changes: 5 additions & 0 deletions src/include/delta_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ namespace duckdb {
class DeltaFunctions {
public:
static vector<TableFunctionSet> GetTableFunctions(DatabaseInstance &instance);
static vector<ScalarFunctionSet> GetScalarFunctions(DatabaseInstance &instance);

private:
//! Table Functions
static TableFunctionSet GetDeltaScanFunction(DatabaseInstance &instance);

//! Scalar Functions
static ScalarFunctionSet GetExpressionFunction(DatabaseInstance &instance);
};
} // namespace duckdb
Loading

0 comments on commit b105f2c

Please sign in to comment.