diff --git a/chsql/CMakeLists.txt b/chsql/CMakeLists.txt
index 06de4a9..f2b1f46 100644
--- a/chsql/CMakeLists.txt
+++ b/chsql/CMakeLists.txt
@@ -21,7 +21,7 @@ include_directories(
   ../duckdb/third_party/mbedtls
   ../duckdb/third_party/mbedtls/include
   ../duckdb/third_party/brotli/include)
-set(EXTENSION_SOURCES src/chsql_extension.cpp)
+set(EXTENSION_SOURCES src/chsql_extension.cpp src/duck_flock.cpp)
 build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
 build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
 # Link OpenSSL in both the static library as the loadable extension
diff --git a/chsql/src/chsql_extension.cpp b/chsql/src/chsql_extension.cpp
index 0bec6db..5dd8fe7 100644
--- a/chsql/src/chsql_extension.cpp
+++ b/chsql/src/chsql_extension.cpp
@@ -216,7 +216,7 @@ static void LoadInternal(DatabaseInstance &instance) {
 	ExtensionUtil::RegisterFunction(instance, chsql_openssl_version_scalar_function);
 
 	// Macros
-	for (idx_t index = 0; chsql_macros[index].name != nullptr; index++) {
+    for (idx_t index = 0; chsql_macros[index].name != nullptr; index++) {
 		auto info = DefaultFunctionGenerator::CreateInternalMacroInfo(chsql_macros[index]);
 		ExtensionUtil::RegisterFunction(instance, *info);
 	}
@@ -226,6 +226,8 @@ static void LoadInternal(DatabaseInstance &instance) {
 		ExtensionUtil::RegisterFunction(instance, *table_info);
 	}
 	ExtensionUtil::RegisterFunction(instance, ReadParquetOrderedFunction());
+	// Flock
+	ExtensionUtil::RegisterFunction(instance, DuckFlockTableFunction());
 }
 
 void ChsqlExtension::Load(DuckDB &db) {
diff --git a/chsql/src/duck_flock.cpp b/chsql/src/duck_flock.cpp
new file mode 100644
index 0000000..90e3588
--- /dev/null
+++ b/chsql/src/duck_flock.cpp
@@ -0,0 +1,155 @@
+#ifndef DUCK_FLOCK_H
+#define DUCK_FLOCK_H
+#include "chsql_extension.hpp"
+
+namespace duckdb {
+    // Bind data for url_flock: the connections opened during bind and the
+    // query results obtained on them. Both are stored side by side,
+    // presumably so each connection outlives the result it produced.
+    struct DuckFlockData : FunctionData {
+        vector<unique_ptr<Connection>> conn;
+        vector<unique_ptr<QueryResult>> results;
+        unique_ptr<FunctionData> Copy() const override {
+            throw std::runtime_error("not implemented");
+        }
+        bool Equals(const FunctionData &other) const override {
+            throw std::runtime_error("not implemented");
+        }
+    };
+
+    // Bind callback for url_flock(query, urls).
+    //
+    // Runs the query against every non-empty entry of the url list through
+    // read_json on a fresh connection and keeps each successful result in the
+    // bind data. The first successful result defines the output types and
+    // names; entries that fail at any step are skipped.
+    //
+    // Throws std::runtime_error when a parameter is missing, NULL or empty, or
+    // when no entry produced a usable result.
+    unique_ptr<FunctionData> DuckFlockBind(ClientContext &context, TableFunctionBindInput &input,
+                                           vector<LogicalType> &return_types, vector<string> &names) {
+        auto data = make_uniq<DuckFlockData>();
+
+        // Check for missing or NULL input parameters
+        if (input.inputs.size() < 2) {
+            throw std::runtime_error("url_flock: missing required parameters");
+        }
+        if (input.inputs[0].IsNull() || input.inputs[1].IsNull()) {
+            throw std::runtime_error("url_flock: NULL parameters are not allowed");
+        }
+
+        auto strQuery = input.inputs[0].GetValue<string>();
+        if (strQuery.empty()) {
+            throw std::runtime_error("url_flock: empty query string");
+        }
+
+        auto &raw_flock = ListValue::GetChildren(input.inputs[1]);
+        if (raw_flock.empty()) {
+            throw std::runtime_error("url_flock: empty flock list");
+        }
+
+        bool has_valid_result = false;
+        // Process each connection
+        for (auto &duck : raw_flock) {
+            if (duck.IsNull() || duck.ToString().empty()) {
+                continue;
+            }
+
+            try {
+                auto conn = make_uniq<Connection>(*context.db);
+                if (!conn) {
+                    continue;
+                }
+
+                auto settingResult = conn->Query("SET autoload_known_extensions=1;SET autoinstall_known_extensions=1;");
+                if (settingResult->HasError()) {
+                    continue;
+                }
+
+                auto req = conn->Prepare("SELECT * FROM read_json($2 || '/?default_format=JSONEachRow&query=' || url_encode($1::VARCHAR))");
+                if (req->HasError()) {
+                    continue;
+                }
+
+                auto queryResult = req->Execute(strQuery.c_str(), duck.ToString());
+                if (!queryResult || queryResult->HasError()) {
+                    continue;
+                }
+
+                // A result without columns cannot define the output schema
+                if (queryResult->types.empty()) {
+                    continue;
+                }
+
+                if (!has_valid_result) {
+                    // The first valid result defines the output types and names
+                    return_types = queryResult->types;
+                    names = queryResult->names;
+                    has_valid_result = true;
+                } else if (queryResult->types != return_types) {
+                    // Every chunk is appended to the same output, so a result
+                    // with a different column layout cannot be emitted with it
+                    continue;
+                }
+
+                data->conn.push_back(std::move(conn));
+                data->results.push_back(std::move(queryResult));
+            } catch (const std::exception &) {
+                // Best effort: one failing entry must not abort the others
+                continue;
+            }
+        }
+
+        // Verify we have at least one valid result
+        if (!has_valid_result || data->results.empty()) {
+            throw std::runtime_error("url_flock: invalid or no results");
+        }
+
+        return std::move(data);
+    }
+
+    // Scan callback for url_flock: appends the next non-empty chunk to output.
+    //
+    // Every call walks the stored results in order and returns after the first
+    // one that yields rows; a result whose fetch fails or yields no rows is
+    // skipped. When none yields rows, output is left untouched.
+    void DuckFlockImplementation(ClientContext &context, TableFunctionInput &data_p,
+                                 DataChunk &output) {
+        auto &data = data_p.bind_data->Cast<DuckFlockData>();
+
+        for (const auto &res : data.results) {
+            if (!res) {
+                continue;
+            }
+
+            ErrorData error_data;
+            auto data_chunk = make_uniq<DataChunk>();
+
+            try {
+                if (res->TryFetch(data_chunk, error_data)) {
+                    // Only forward chunks that actually contain rows
+                    if (data_chunk && data_chunk->size() != 0) {
+                        output.Append(*data_chunk);
+                        return;
+                    }
+                }
+            } catch (...) {
+                continue;
+            }
+        }
+    }
+
+    // Builds the url_flock(VARCHAR, VARCHAR[]) table function.
+    TableFunction DuckFlockTableFunction() {
+        TableFunction f(
+            "url_flock",
+            {LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)},
+            DuckFlockImplementation,
+            DuckFlockBind,
+            nullptr,
+            nullptr
+        );
+        return f;
+    }
+}
+#endif
diff --git a/chsql/src/include/chsql_extension.hpp b/chsql/src/include/chsql_extension.hpp
index 33e2c6f..b1c7a5e 100644
--- a/chsql/src/include/chsql_extension.hpp
+++ b/chsql/src/include/chsql_extension.hpp
@@ -12,4 +12,7 @@ class ChsqlExtension : public Extension {
 };
 duckdb::TableFunction ReadParquetOrderedFunction();
 static void RegisterSillyBTreeStore(DatabaseInstance &instance);
+
+TableFunction DuckFlockTableFunction();
+
 } // namespace duckdb