diff --git a/src/include/storage/uc_catalog.hpp b/src/include/storage/uc_catalog.hpp index 2fdbb22..1f333d1 100644 --- a/src/include/storage/uc_catalog.hpp +++ b/src/include/storage/uc_catalog.hpp @@ -17,8 +17,8 @@ namespace duckdb { class UCSchemaEntry; struct UCCredentials { - string endpoint; - string token; + string endpoint; + string token; // Not really part of the credentials, but required to query s3 tables string aws_region; @@ -26,19 +26,20 @@ struct UCCredentials { class UCClearCacheFunction : public TableFunction { public: - UCClearCacheFunction(); + UCClearCacheFunction(); - static void ClearCacheOnSetting(ClientContext &context, SetScope scope, Value ¶meter); + static void ClearCacheOnSetting(ClientContext &context, SetScope scope, Value ¶meter); }; class UCCatalog : public Catalog { public: - explicit UCCatalog(AttachedDatabase &db_p, const string &internal_name, AccessMode access_mode, UCCredentials credentials); + explicit UCCatalog(AttachedDatabase &db_p, const string &internal_name, AccessMode access_mode, + UCCredentials credentials); ~UCCatalog(); string internal_name; AccessMode access_mode; - UCCredentials credentials; + UCCredentials credentials; public: void Initialize(bool load_builtin) override; diff --git a/src/include/storage/uc_schema_entry.hpp b/src/include/storage/uc_schema_entry.hpp index c68ef32..b9d11f1 100644 --- a/src/include/storage/uc_schema_entry.hpp +++ b/src/include/storage/uc_schema_entry.hpp @@ -18,9 +18,10 @@ class UCTransaction; class UCSchemaEntry : public SchemaCatalogEntry { public: UCSchemaEntry(Catalog &catalog, CreateSchemaInfo &info); - ~UCSchemaEntry() override; + ~UCSchemaEntry() override; + + unique_ptr schema_data; - unique_ptr schema_data; public: optional_ptr CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) override; optional_ptr CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) override; diff --git a/src/include/storage/uc_table_entry.hpp b/src/include/storage/uc_table_entry.hpp index 7174405..50b2528 100644 --- a/src/include/storage/uc_table_entry.hpp +++ b/src/include/storage/uc_table_entry.hpp @@ -37,7 +37,8 @@ class UCTableEntry : public TableCatalogEntry { UCTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info); UCTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, UCTableInfo &info); - unique_ptr table_data; + unique_ptr table_data; + public: unique_ptr GetStatistics(ClientContext &context, column_t column_id) override; diff --git a/src/include/storage/uc_table_set.hpp b/src/include/storage/uc_table_set.hpp index 9788d3c..fd8ec0b 100644 --- a/src/include/storage/uc_table_set.hpp +++ b/src/include/storage/uc_table_set.hpp @@ -24,7 +24,7 @@ class UCTableSet : public UCInSchemaSet { optional_ptr CreateTable(ClientContext &context, BoundCreateTableInfo &info); static unique_ptr GetTableInfo(ClientContext &context, UCSchemaEntry &schema, - const string &table_name); + const string &table_name); optional_ptr RefreshTable(ClientContext &context, const string &table_name); void AlterTable(ClientContext &context, AlterTableInfo &info); @@ -37,8 +37,7 @@ class UCTableSet : public UCInSchemaSet { void AlterTable(ClientContext &context, AddColumnInfo &info); void AlterTable(ClientContext &context, RemoveColumnInfo &info); - static void AddColumn(ClientContext &context, UCResult &result, UCTableInfo &table_info, - idx_t column_offset = 0); + static void AddColumn(ClientContext &context, UCResult &result, UCTableInfo &table_info, idx_t column_offset = 0); }; } // namespace duckdb diff --git a/src/include/storage/uc_transaction.hpp b/src/include/storage/uc_transaction.hpp index a3bb951..86a1423 100644 --- a/src/include/storage/uc_transaction.hpp +++ b/src/include/storage/uc_transaction.hpp @@ -26,15 +26,15 @@ class UCTransaction : public Transaction { void Commit(); void Rollback(); -// UCConnection &GetConnection(); -// unique_ptr Query(const string &query); + // UCConnection &GetConnection(); + // unique_ptr Query(const string &query); static UCTransaction &Get(ClientContext &context, Catalog &catalog); AccessMode GetAccessMode() const { return access_mode; } private: -// UCConnection connection; + // UCConnection connection; UCTransactionState transaction_state; AccessMode access_mode; }; diff --git a/src/include/uc_api.hpp b/src/include/uc_api.hpp index 57c5d79..4be9dcd 100644 --- a/src/include/uc_api.hpp +++ b/src/include/uc_api.hpp @@ -14,46 +14,46 @@ namespace duckdb { struct UCCredentials; struct UCAPIColumnDefinition { - string name; - string type_text; - idx_t precision; - idx_t scale; - idx_t position; + string name; + string type_text; + idx_t precision; + idx_t scale; + idx_t position; }; struct UCAPITable { - string table_id; + string table_id; - string name; - string catalog_name; - string schema_name; - string table_type; - string data_source_format; + string name; + string catalog_name; + string schema_name; + string table_type; + string data_source_format; - string storage_location; - string delta_last_commit_timestamp; - string delta_last_update_version; + string storage_location; + string delta_last_commit_timestamp; + string delta_last_update_version; - vector columns; + vector columns; }; struct UCAPISchema { - string schema_name; - string catalog_name; + string schema_name; + string catalog_name; }; struct UCAPITableCredentials { - string key_id; - string secret; - string session_token; + string key_id; + string secret; + string session_token; }; class UCAPI { public: - static UCAPITableCredentials GetTableCredentials(const string &table_id, UCCredentials credentials); - static vector GetCatalogs(const string &catalog, UCCredentials credentials); - static vector GetTables(const string &catalog, const string &schema, UCCredentials credentials); - static vector GetSchemas(const string &catalog, UCCredentials credentials); - static vector GetTablesInSchema(const string &catalog, const string &schema, UCCredentials credentials); + static UCAPITableCredentials GetTableCredentials(const string &table_id, UCCredentials credentials); + static vector GetCatalogs(const string &catalog, UCCredentials credentials); + static vector GetTables(const string &catalog, const string &schema, UCCredentials credentials); + static vector GetSchemas(const string &catalog, UCCredentials credentials); + static vector GetTablesInSchema(const string &catalog, const string &schema, UCCredentials credentials); }; } // namespace duckdb diff --git a/src/include/uc_catalog_extension.hpp b/src/include/uc_catalog_extension.hpp index f7cd170..cbd0cfe 100644 --- a/src/include/uc_catalog_extension.hpp +++ b/src/include/uc_catalog_extension.hpp @@ -6,9 +6,9 @@ namespace duckdb { class UcCatalogExtension : public Extension { public: - void Load(DuckDB &db) override; - std::string Name() override; - std::string Version() const override; + void Load(DuckDB &db) override; + std::string Name() override; + std::string Version() const override; }; } // namespace duckdb diff --git a/src/storage/uc_catalog.cpp b/src/storage/uc_catalog.cpp index 96e1d20..98bd644 100644 --- a/src/storage/uc_catalog.cpp +++ b/src/storage/uc_catalog.cpp @@ -8,14 +8,15 @@ namespace duckdb { -UCCatalog::UCCatalog(AttachedDatabase &db_p, const string &internal_name, AccessMode access_mode, UCCredentials credentials) - : Catalog(db_p), internal_name(internal_name),access_mode(access_mode), credentials(std::move(credentials)), schemas(*this) { +UCCatalog::UCCatalog(AttachedDatabase &db_p, const string &internal_name, AccessMode access_mode, + UCCredentials credentials) + : Catalog(db_p), internal_name(internal_name), access_mode(access_mode), credentials(std::move(credentials)), + schemas(*this) { } UCCatalog::~UCCatalog() = default; void UCCatalog::Initialize(bool load_builtin) { - } optional_ptr UCCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) { @@ -39,8 +40,7 @@ void UCCatalog::ScanSchemas(ClientContext &context, std::function UCCatalog::GetSchema(CatalogTransaction transaction, const string &schema_name, - OnEntryNotFound if_not_found, - QueryErrorContext error_context) { + OnEntryNotFound if_not_found, QueryErrorContext error_context) { if (schema_name == DEFAULT_SCHEMA) { if (default_schema.empty()) { throw InvalidInputException("Attempting to fetch the default schema - but no database was " @@ -77,24 +77,24 @@ void UCCatalog::ClearCache() { } unique_ptr UCCatalog::PlanInsert(ClientContext &context, LogicalInsert &op, - unique_ptr plan) { - throw NotImplementedException("UCCatalog PlanInsert"); + unique_ptr plan) { + throw NotImplementedException("UCCatalog PlanInsert"); } unique_ptr UCCatalog::PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op, - unique_ptr plan) { - throw NotImplementedException("UCCatalog PlanCreateTableAs"); + unique_ptr plan) { + throw NotImplementedException("UCCatalog PlanCreateTableAs"); } unique_ptr UCCatalog::PlanDelete(ClientContext &context, LogicalDelete &op, - unique_ptr plan) { - throw NotImplementedException("UCCatalog PlanDelete"); + unique_ptr plan) { + throw NotImplementedException("UCCatalog PlanDelete"); } unique_ptr UCCatalog::PlanUpdate(ClientContext &context, LogicalUpdate &op, - unique_ptr plan) { - throw NotImplementedException("UCCatalog PlanUpdate"); + unique_ptr plan) { + throw NotImplementedException("UCCatalog PlanUpdate"); } unique_ptr UCCatalog::BindCreateIndex(Binder &binder, CreateStatement &stmt, TableCatalogEntry &table, - unique_ptr plan) { - throw NotImplementedException("UCCatalog BindCreateIndex"); + unique_ptr plan) { + throw NotImplementedException("UCCatalog BindCreateIndex"); } } // namespace duckdb diff --git a/src/storage/uc_catalog_set.cpp b/src/storage/uc_catalog_set.cpp index 597f0cc..6451087 100644 --- a/src/storage/uc_catalog_set.cpp +++ b/src/storage/uc_catalog_set.cpp @@ -22,7 +22,7 @@ optional_ptr UCCatalogSet::GetEntry(ClientContext &context, const } void UCCatalogSet::DropEntry(ClientContext &context, DropInfo &info) { - throw NotImplementedException("UCCatalogSet::DropEntry"); + throw NotImplementedException("UCCatalogSet::DropEntry"); } void UCCatalogSet::EraseEntryInternal(const string &name) { diff --git a/src/storage/uc_clear_cache.cpp b/src/storage/uc_clear_cache.cpp index c7e91ce..89ccddb 100644 --- a/src/storage/uc_clear_cache.cpp +++ b/src/storage/uc_clear_cache.cpp @@ -45,7 +45,6 @@ void UCClearCacheFunction::ClearCacheOnSetting(ClientContext &context, SetScope ClearUCCaches(context); } -UCClearCacheFunction::UCClearCacheFunction() - : TableFunction("uc_clear_cache", {}, ClearCacheFunction, ClearCacheBind) { +UCClearCacheFunction::UCClearCacheFunction() : TableFunction("uc_clear_cache", {}, ClearCacheFunction, ClearCacheBind) { } } // namespace duckdb diff --git a/src/storage/uc_schema_entry.cpp b/src/storage/uc_schema_entry.cpp index 702a3ab..709c7b7 100644 --- a/src/storage/uc_schema_entry.cpp +++ b/src/storage/uc_schema_entry.cpp @@ -17,7 +17,7 @@ UCSchemaEntry::UCSchemaEntry(Catalog &catalog, CreateSchemaInfo &info) : SchemaCatalogEntry(catalog, info), tables(*this) { } -UCSchemaEntry::~UCSchemaEntry(){ +UCSchemaEntry::~UCSchemaEntry() { } UCTransaction &GetUCTransaction(CatalogTransaction transaction) { @@ -51,7 +51,7 @@ void UCUnqualifyColumnRef(ParsedExpression &expr) { } optional_ptr UCSchemaEntry::CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info, - TableCatalogEntry &table) { + TableCatalogEntry &table) { throw NotImplementedException("CreateIndex"); } @@ -71,11 +71,11 @@ optional_ptr UCSchemaEntry::CreateView(CatalogTransaction transact if (info.on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT) { return current_entry; } - throw NotImplementedException("REPLACE ON CONFLICT in CreateView"); + throw NotImplementedException("REPLACE ON CONFLICT in CreateView"); } } auto &uc_transaction = GetUCTransaction(transaction); -// uc_transaction.Query(GetUCCreateView(info)); + // uc_transaction.Query(GetUCCreateView(info)); return tables.RefreshTable(transaction.GetContext(), info.view_name); } @@ -88,22 +88,21 @@ optional_ptr UCSchemaEntry::CreateSequence(CatalogTransaction tran } optional_ptr UCSchemaEntry::CreateTableFunction(CatalogTransaction transaction, - CreateTableFunctionInfo &info) { + CreateTableFunctionInfo &info) { throw BinderException("UC databases do not support creating table functions"); } optional_ptr UCSchemaEntry::CreateCopyFunction(CatalogTransaction transaction, - CreateCopyFunctionInfo &info) { + CreateCopyFunctionInfo &info) { throw BinderException("UC databases do not support creating copy functions"); } optional_ptr UCSchemaEntry::CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) { + CreatePragmaFunctionInfo &info) { throw BinderException("UC databases do not support creating pragma functions"); } -optional_ptr UCSchemaEntry::CreateCollation(CatalogTransaction transaction, - CreateCollationInfo &info) { +optional_ptr UCSchemaEntry::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { throw BinderException("UC databases do not support creating collations"); } @@ -127,7 +126,7 @@ bool CatalogTypeIsSupported(CatalogType type) { } void UCSchemaEntry::Scan(ClientContext &context, CatalogType type, - const std::function &callback) { + const std::function &callback) { if (!CatalogTypeIsSupported(type)) { return; } @@ -142,7 +141,7 @@ void UCSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { } optional_ptr UCSchemaEntry::GetEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { + const string &name) { if (!CatalogTypeIsSupported(type)) { return nullptr; } diff --git a/src/storage/uc_schema_set.cpp b/src/storage/uc_schema_set.cpp index 251067b..9138171 100644 --- a/src/storage/uc_schema_set.cpp +++ b/src/storage/uc_schema_set.cpp @@ -10,25 +10,25 @@ namespace duckdb { UCSchemaSet::UCSchemaSet(Catalog &catalog) : UCCatalogSet(catalog) { } -static bool IsInternalTable(const string & catalog, const string &schema) { - if (schema == "information_schema") { - return true; - } - return false; +static bool IsInternalTable(const string &catalog, const string &schema) { + if (schema == "information_schema") { + return true; + } + return false; } void UCSchemaSet::LoadEntries(ClientContext &context) { - auto &uc_catalog = catalog.Cast(); - auto tables = UCAPI::GetSchemas(catalog.GetName(), uc_catalog.credentials); + auto &uc_catalog = catalog.Cast(); + auto tables = UCAPI::GetSchemas(catalog.GetName(), uc_catalog.credentials); - for (const auto& schema: tables) { - CreateSchemaInfo info; - info.schema = schema.schema_name; - info.internal = IsInternalTable(schema.catalog_name, schema.schema_name); - auto schema_entry = make_uniq(catalog, info); - schema_entry->schema_data = make_uniq(schema); - CreateEntry(std::move(schema_entry)); - } + for (const auto &schema : tables) { + CreateSchemaInfo info; + info.schema = schema.schema_name; + info.internal = IsInternalTable(schema.catalog_name, schema.schema_name); + auto schema_entry = make_uniq(catalog, info); + schema_entry->schema_data = make_uniq(schema); + CreateEntry(std::move(schema_entry)); + } } optional_ptr UCSchemaSet::CreateSchema(ClientContext &context, CreateSchemaInfo &info) { diff --git a/src/storage/uc_table_entry.cpp b/src/storage/uc_table_entry.cpp index 744e697..c333825 100644 --- a/src/storage/uc_table_entry.cpp +++ b/src/storage/uc_table_entry.cpp @@ -28,55 +28,58 @@ unique_ptr UCTableEntry::GetStatistics(ClientContext &context, c return nullptr; } -void UCTableEntry::BindUpdateConstraints(Binder &binder, LogicalGet &, LogicalProjection &, LogicalUpdate &, ClientContext &) { - throw NotImplementedException("BindUpdateConstraints"); +void UCTableEntry::BindUpdateConstraints(Binder &binder, LogicalGet &, LogicalProjection &, LogicalUpdate &, + ClientContext &) { + throw NotImplementedException("BindUpdateConstraints"); } TableFunction UCTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { - auto &db = DatabaseInstance::GetDatabase(context); - auto &delta_function_set = ExtensionUtil::GetTableFunction(db, "delta_scan"); - auto delta_scan_function = delta_function_set.functions.GetFunctionByArguments(context, {LogicalType::VARCHAR}); + auto &db = DatabaseInstance::GetDatabase(context); + auto &delta_function_set = ExtensionUtil::GetTableFunction(db, "delta_scan"); + auto delta_scan_function = delta_function_set.functions.GetFunctionByArguments(context, {LogicalType::VARCHAR}); auto &uc_catalog = catalog.Cast(); - D_ASSERT(table_data); + D_ASSERT(table_data); if (table_data->data_source_format != "DELTA") { - throw NotImplementedException("Table '%s' is of unsupported format '%s', ", table_data->name, table_data->data_source_format); + throw NotImplementedException("Table '%s' is of unsupported format '%s', ", table_data->name, + table_data->data_source_format); } // Set the S3 path as input to table function - vector inputs = {table_data->storage_location}; - - if (table_data->storage_location.find("file://") != 0) { - auto &secret_manager = SecretManager::Get(context); - // Get Credentials from UCAPI - auto table_credentials = UCAPI::GetTableCredentials(table_data->table_id, uc_catalog.credentials); - - // Inject secret into secret manager scoped to this path - CreateSecretInfo info(OnCreateConflict::REPLACE_ON_CONFLICT, SecretPersistType::TEMPORARY); - info.name = "__internal_uc_" + table_data->table_id; - info.type = "s3"; - info.provider = "config"; - info.options = { - {"key_id", table_credentials.key_id}, - {"secret", table_credentials.secret}, - {"session_token", table_credentials.session_token}, - {"region", uc_catalog.credentials.aws_region}, - }; - info.scope = {table_data->storage_location}; - secret_manager.CreateSecret(context, info); - } - named_parameter_map_t param_map; - vector return_types; - vector names; - TableFunctionRef empty_ref; - - TableFunctionBindInput bind_input(inputs, param_map, return_types, names, nullptr, nullptr, delta_scan_function, empty_ref); - - auto result = delta_scan_function.bind(context, bind_input, return_types, names); - bind_data = std::move(result); - - return delta_scan_function; + vector inputs = {table_data->storage_location}; + + if (table_data->storage_location.find("file://") != 0) { + auto &secret_manager = SecretManager::Get(context); + // Get Credentials from UCAPI + auto table_credentials = UCAPI::GetTableCredentials(table_data->table_id, uc_catalog.credentials); + + // Inject secret into secret manager scoped to this path + CreateSecretInfo info(OnCreateConflict::REPLACE_ON_CONFLICT, SecretPersistType::TEMPORARY); + info.name = "__internal_uc_" + table_data->table_id; + info.type = "s3"; + info.provider = "config"; + info.options = { + {"key_id", table_credentials.key_id}, + {"secret", table_credentials.secret}, + {"session_token", table_credentials.session_token}, + {"region", uc_catalog.credentials.aws_region}, + }; + info.scope = {table_data->storage_location}; + secret_manager.CreateSecret(context, info); + } + named_parameter_map_t param_map; + vector return_types; + vector names; + TableFunctionRef empty_ref; + + TableFunctionBindInput bind_input(inputs, param_map, return_types, names, nullptr, nullptr, delta_scan_function, + empty_ref); + + auto result = delta_scan_function.bind(context, bind_input, return_types, names); + bind_data = std::move(result); + + return delta_scan_function; } TableStorageInfo UCTableEntry::GetStorageInfo(ClientContext &context) { diff --git a/src/storage/uc_table_set.cpp b/src/storage/uc_table_set.cpp index 9374f1e..7819338 100644 --- a/src/storage/uc_table_set.cpp +++ b/src/storage/uc_table_set.cpp @@ -21,30 +21,30 @@ UCTableSet::UCTableSet(UCSchemaEntry &schema) : UCInSchemaSet(schema) { } static ColumnDefinition CreateColumnDefinition(ClientContext &context, UCAPIColumnDefinition &coldef) { - return {coldef.name, UCUtils::TypeToLogicalType(context, coldef.type_text)}; + return {coldef.name, UCUtils::TypeToLogicalType(context, coldef.type_text)}; } void UCTableSet::LoadEntries(ClientContext &context) { auto &transaction = UCTransaction::Get(context, catalog); - auto &uc_catalog = catalog.Cast(); + auto &uc_catalog = catalog.Cast(); - // TODO: handle out-of-order columns using position property - auto tables = UCAPI::GetTables(catalog.GetDBPath(), schema.name, uc_catalog.credentials); + // TODO: handle out-of-order columns using position property + auto tables = UCAPI::GetTables(catalog.GetDBPath(), schema.name, uc_catalog.credentials); - for (auto &table : tables) { - D_ASSERT(schema.name == table.schema_name); - CreateTableInfo info; - for (auto &col : table.columns) { - info.columns.AddColumn(CreateColumnDefinition(context, col)); - } + for (auto &table : tables) { + D_ASSERT(schema.name == table.schema_name); + CreateTableInfo info; + for (auto &col : table.columns) { + info.columns.AddColumn(CreateColumnDefinition(context, col)); + } - info.table = table.name; - auto table_entry = make_uniq(catalog, schema, info); - table_entry->table_data = make_uniq(table); + info.table = table.name; + auto table_entry = make_uniq(catalog, schema, info); + table_entry->table_data = make_uniq(table); - CreateEntry(std::move(table_entry)); - } + CreateEntry(std::move(table_entry)); + } } optional_ptr UCTableSet::RefreshTable(ClientContext &context, const string &table_name) { @@ -56,32 +56,32 @@ optional_ptr UCTableSet::RefreshTable(ClientContext &context, cons } unique_ptr UCTableSet::GetTableInfo(ClientContext &context, UCSchemaEntry &schema, - const string &table_name) { - throw NotImplementedException("UCTableSet::CreateTable"); + const string &table_name) { + throw NotImplementedException("UCTableSet::CreateTable"); } optional_ptr UCTableSet::CreateTable(ClientContext &context, BoundCreateTableInfo &info) { - throw NotImplementedException("UCTableSet::CreateTable"); + throw NotImplementedException("UCTableSet::CreateTable"); } void UCTableSet::AlterTable(ClientContext &context, RenameTableInfo &info) { - throw NotImplementedException("UCTableSet::AlterTable"); + throw NotImplementedException("UCTableSet::AlterTable"); } void UCTableSet::AlterTable(ClientContext &context, RenameColumnInfo &info) { - throw NotImplementedException("UCTableSet::AlterTable"); + throw NotImplementedException("UCTableSet::AlterTable"); } void UCTableSet::AlterTable(ClientContext &context, AddColumnInfo &info) { - throw NotImplementedException("UCTableSet::AlterTable"); + throw NotImplementedException("UCTableSet::AlterTable"); } void UCTableSet::AlterTable(ClientContext &context, RemoveColumnInfo &info) { - throw NotImplementedException("UCTableSet::AlterTable"); + throw NotImplementedException("UCTableSet::AlterTable"); } void UCTableSet::AlterTable(ClientContext &context, AlterTableInfo &alter) { - throw NotImplementedException("UCTableSet::AlterTable"); + throw NotImplementedException("UCTableSet::AlterTable"); } } // namespace duckdb diff --git a/src/storage/uc_transaction.cpp b/src/storage/uc_transaction.cpp index 00e074c..2942458 100644 --- a/src/storage/uc_transaction.cpp +++ b/src/storage/uc_transaction.cpp @@ -8,7 +8,7 @@ namespace duckdb { UCTransaction::UCTransaction(UCCatalog &uc_catalog, TransactionManager &manager, ClientContext &context) : Transaction(manager, context), access_mode(uc_catalog.access_mode) { -// connection = UCConnection::Open(uc_catalog.path); + // connection = UCConnection::Open(uc_catalog.path); } UCTransaction::~UCTransaction() = default; @@ -19,17 +19,17 @@ void UCTransaction::Start() { void UCTransaction::Commit() { if (transaction_state == UCTransactionState::TRANSACTION_STARTED) { transaction_state = UCTransactionState::TRANSACTION_FINISHED; -// connection.Execute("COMMIT"); + // connection.Execute("COMMIT"); } } void UCTransaction::Rollback() { if (transaction_state == UCTransactionState::TRANSACTION_STARTED) { transaction_state = UCTransactionState::TRANSACTION_FINISHED; -// connection.Execute("ROLLBACK"); + // connection.Execute("ROLLBACK"); } } -//UCConnection &UCTransaction::GetConnection() { +// UCConnection &UCTransaction::GetConnection() { // if (transaction_state == UCTransactionState::TRANSACTION_NOT_YET_STARTED) { // transaction_state = UCTransactionState::TRANSACTION_STARTED; // string query = "START TRANSACTION"; @@ -41,7 +41,7 @@ void UCTransaction::Rollback() { // return connection; //} -//unique_ptr UCTransaction::Query(const string &query) { +// unique_ptr UCTransaction::Query(const string &query) { // if (transaction_state == UCTransactionState::TRANSACTION_NOT_YET_STARTED) { // transaction_state = UCTransactionState::TRANSACTION_STARTED; // string transaction_start = "START TRANSACTION"; diff --git a/src/storage/uc_transaction_manager.cpp b/src/storage/uc_transaction_manager.cpp index 0c941dc..afd7e07 100644 --- a/src/storage/uc_transaction_manager.cpp +++ b/src/storage/uc_transaction_manager.cpp @@ -33,8 +33,8 @@ void UCTransactionManager::RollbackTransaction(Transaction &transaction) { void UCTransactionManager::Checkpoint(ClientContext &context, bool force) { auto &transaction = UCTransaction::Get(context, db.GetCatalog()); -// auto &db = transaction.GetConnection(); -// db.Execute("CHECKPOINT"); + // auto &db = transaction.GetConnection(); + // db.Execute("CHECKPOINT"); } } // namespace duckdb diff --git a/src/uc_api.cpp b/src/uc_api.cpp index 2e575c7..0ddcc85 100644 --- a/src/uc_api.cpp +++ b/src/uc_api.cpp @@ -3,99 +3,104 @@ #include "yyjson.hpp" #include - namespace duckdb { -static size_t GetRequestWriteCallback(void *contents, size_t size, size_t nmemb, void *userp) -{ - ((std::string*)userp)->append((char*)contents, size * nmemb); - return size * nmemb; +static size_t GetRequestWriteCallback(void *contents, size_t size, size_t nmemb, void *userp) { + ((std::string *)userp)->append((char *)contents, size * nmemb); + return size * nmemb; } -static string GetRequest(const string& url, const string& token = ""){ - CURL *curl; - CURLcode res; - string readBuffer; - - curl = curl_easy_init(); - if(curl) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, GetRequestWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - if (!token.empty()) { - curl_easy_setopt(curl, CURLOPT_XOAUTH2_BEARER, token.c_str()); - curl_easy_setopt(curl, CURLOPT_HTTPAUTH, CURLAUTH_BEARER); - } - res = curl_easy_perform(curl); - curl_easy_cleanup(curl); - - if (res != CURLcode::CURLE_OK) { - string error = curl_easy_strerror(res); - throw IOException("Curl Request to '%s' failed with error: '%s'", url, error); - } - return readBuffer; - } - throw InternalException("Failed to initialize curl"); +static string GetRequest(const string &url, const string &token = "") { + CURL *curl; + CURLcode res; + string readBuffer; + + curl = curl_easy_init(); + if (curl) { + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, GetRequestWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + if (!token.empty()) { + curl_easy_setopt(curl, CURLOPT_XOAUTH2_BEARER, token.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPAUTH, CURLAUTH_BEARER); + } + res = curl_easy_perform(curl); + curl_easy_cleanup(curl); + + if (res != CURLcode::CURLE_OK) { + string error = curl_easy_strerror(res); + throw IOException("Curl Request to '%s' failed with error: '%s'", url, error); + } + return readBuffer; + } + throw InternalException("Failed to initialize curl"); } template -static TYPE TemplatedTryGetYYJson(duckdb_yyjson::yyjson_val *obj, const string &field, TYPE default_val, bool fail_on_missing = true) { - auto val = yyjson_obj_get(obj, field.c_str()); - if (val && yyjson_get_type(val) == TYPE_NUM) { - return get_function(val); - } else if (!fail_on_missing) { - return default_val; - } - throw IOException("Invalid field found while parsing field: " + field); +static TYPE TemplatedTryGetYYJson(duckdb_yyjson::yyjson_val *obj, const string &field, TYPE default_val, + bool fail_on_missing = true) { + auto val = yyjson_obj_get(obj, field.c_str()); + if (val && yyjson_get_type(val) == TYPE_NUM) { + return get_function(val); + } else if (!fail_on_missing) { + return default_val; + } + throw IOException("Invalid field found while parsing field: " + field); } -static uint64_t TryGetNumFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = true, uint64_t default_val = 0) { - return TemplatedTryGetYYJson(obj, field, default_val, fail_on_missing); +static uint64_t TryGetNumFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = true, + uint64_t default_val = 0) { + return TemplatedTryGetYYJson(obj, field, default_val, + fail_on_missing); } -static bool TryGetBoolFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = false, bool default_val = false) { - return TemplatedTryGetYYJson(obj, field, default_val, fail_on_missing); +static bool TryGetBoolFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = false, + bool default_val = false) { + return TemplatedTryGetYYJson(obj, field, default_val, + fail_on_missing); } -static string TryGetStrFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = true, const char* default_val = "") { - return TemplatedTryGetYYJson(obj, field, default_val, fail_on_missing); +static string TryGetStrFromObject(duckdb_yyjson::yyjson_val *obj, const string &field, bool fail_on_missing = true, + const char *default_val = "") { + return TemplatedTryGetYYJson(obj, field, default_val, + fail_on_missing); } -static string GetCredentialsRequest(const string& url, const string &table_id, const string& token = ""){ - CURL *curl; - CURLcode res; - string readBuffer; - - string body = StringUtil::Format(R"({"table_id" : "%s", "operation" : "READ"})", table_id); - - curl = curl_easy_init(); - if(curl) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, GetRequestWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - - // Set headers - struct curl_slist *headers = curl_slist_append(nullptr, "Content-Type: application/json"); - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); - - // Set token - if (!token.empty()) { - curl_easy_setopt(curl, CURLOPT_XOAUTH2_BEARER, token.c_str()); - curl_easy_setopt(curl, CURLOPT_HTTPAUTH, CURLAUTH_BEARER); - } - - // Set request body - curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); - curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); - - res = curl_easy_perform(curl); - curl_easy_cleanup(curl); - - if (res != CURLcode::CURLE_OK) { - string error = curl_easy_strerror(res); - throw IOException("Curl Request to '%s' failed with error: '%s'", url, error); - } - return readBuffer; - } - throw InternalException("Failed to initialize curl"); +static string GetCredentialsRequest(const string &url, const string &table_id, const string &token = "") { + CURL *curl; + CURLcode res; + string readBuffer; + + string body = StringUtil::Format(R"({"table_id" : "%s", "operation" : "READ"})", table_id); + + curl = curl_easy_init(); + if (curl) { + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, GetRequestWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + + // Set headers + struct curl_slist *headers = curl_slist_append(nullptr, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + // Set token + if (!token.empty()) { + curl_easy_setopt(curl, CURLOPT_XOAUTH2_BEARER, token.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPAUTH, CURLAUTH_BEARER); + } + + // Set request body + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); + + res = curl_easy_perform(curl); + curl_easy_cleanup(curl); + + if (res != CURLcode::CURLE_OK) { + string error = curl_easy_strerror(res); + throw IOException("Curl Request to '%s' failed with error: '%s'", url, error); + } + return readBuffer; + } + throw InternalException("Failed to initialize curl"); } //# list catalogs @@ -110,108 +115,112 @@ static string GetCredentialsRequest(const string& url, const string &table_id, c // //# list tables in `default` schema // echo "Tables in default schema" -// curl --request GET "https://${DATABRICKS_HOST}/api/2.1/unity-catalog/tables?catalog_name=workspace&schema_name=default" \ +// curl --request GET +// "https://${DATABRICKS_HOST}/api/2.1/unity-catalog/tables?catalog_name=workspace&schema_name=default" \ // --header "Authorization: Bearer ${TOKEN}" | jq . - UCAPITableCredentials UCAPI::GetTableCredentials(const string &table_id, UCCredentials credentials) { - UCAPITableCredentials result; + UCAPITableCredentials result; - auto api_result = GetCredentialsRequest(credentials.endpoint + "/api/2.1/unity-catalog/temporary-table-credentials", table_id, credentials.token); + auto api_result = GetCredentialsRequest(credentials.endpoint + "/api/2.1/unity-catalog/temporary-table-credentials", + table_id, credentials.token); - // Read JSON and get root - duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); - duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); + // Read JSON and get root + duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); + duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); - auto *aws_temp_credentials = yyjson_obj_get(root, "aws_temp_credentials"); - if (aws_temp_credentials) { - result.key_id = TryGetStrFromObject(aws_temp_credentials, "access_key_id"); - result.secret = TryGetStrFromObject(aws_temp_credentials, "secret_access_key"); - result.session_token = TryGetStrFromObject(aws_temp_credentials, "session_token"); - } + auto *aws_temp_credentials = yyjson_obj_get(root, "aws_temp_credentials"); + if (aws_temp_credentials) { + result.key_id = TryGetStrFromObject(aws_temp_credentials, "access_key_id"); + result.secret = TryGetStrFromObject(aws_temp_credentials, "secret_access_key"); + result.session_token = TryGetStrFromObject(aws_temp_credentials, "session_token"); + } - return result; + return result; } vector UCAPI::GetCatalogs(const string &catalog, UCCredentials credentials) { - throw NotImplementedException("UCAPI::GetCatalogs"); + throw NotImplementedException("UCAPI::GetCatalogs"); } static UCAPIColumnDefinition ParseColumnDefinition(duckdb_yyjson::yyjson_val *column_def) { - UCAPIColumnDefinition result; + UCAPIColumnDefinition result; - result.name = TryGetStrFromObject(column_def, "name"); - result.type_text = TryGetStrFromObject(column_def, "type_text"); - result.precision = TryGetNumFromObject(column_def, "type_precision"); - result.scale = TryGetNumFromObject(column_def, "type_scale"); - result.position = TryGetNumFromObject(column_def, "position"); + result.name = TryGetStrFromObject(column_def, "name"); + result.type_text = TryGetStrFromObject(column_def, "type_text"); + result.precision = TryGetNumFromObject(column_def, "type_precision"); + result.scale = TryGetNumFromObject(column_def, "type_scale"); + result.position = TryGetNumFromObject(column_def, "position"); - return result; + return result; } vector UCAPI::GetTables(const string &catalog, const string &schema, UCCredentials credentials) { - vector result; - auto api_result = GetRequest(credentials.endpoint + "/api/2.1/unity-catalog/tables?catalog_name=" + catalog + "&schema_name=" + schema, credentials.token); - - // Read JSON and get root - duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); - duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); - - // Get root["hits"], iterate over the array - auto *tables = yyjson_obj_get(root, "tables"); - size_t idx, max; - duckdb_yyjson::yyjson_val *table; - yyjson_arr_foreach(tables, idx, max, table) { - UCAPITable table_result; - table_result.catalog_name = catalog; - table_result.schema_name = schema; - - table_result.name = TryGetStrFromObject(table, "name"); - table_result.table_type = TryGetStrFromObject(table, "table_type"); - table_result.data_source_format = TryGetStrFromObject(table, "data_source_format", false); - table_result.storage_location = TryGetStrFromObject(table, "storage_location", false); - table_result.table_id = TryGetStrFromObject(table, "table_id"); - - auto *columns = yyjson_obj_get(table, "columns"); - duckdb_yyjson::yyjson_val *col; - size_t col_idx, col_max; - yyjson_arr_foreach(columns, col_idx, col_max, col) { - auto column_definition = ParseColumnDefinition(col); - table_result.columns.push_back(column_definition); - } - - result.push_back(table_result); - } - - return result; + vector result; + auto api_result = GetRequest(credentials.endpoint + "/api/2.1/unity-catalog/tables?catalog_name=" + catalog + + "&schema_name=" + schema, + credentials.token); + + // Read JSON and get root + duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); + duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); + + // Get root["hits"], iterate over the array + auto *tables = yyjson_obj_get(root, "tables"); + size_t idx, max; + duckdb_yyjson::yyjson_val *table; + yyjson_arr_foreach(tables, idx, max, table) { + UCAPITable table_result; + table_result.catalog_name = catalog; + table_result.schema_name = schema; + + table_result.name = TryGetStrFromObject(table, "name"); + table_result.table_type = TryGetStrFromObject(table, "table_type"); + table_result.data_source_format = TryGetStrFromObject(table, "data_source_format", false); + table_result.storage_location = TryGetStrFromObject(table, "storage_location", false); + table_result.table_id = TryGetStrFromObject(table, "table_id"); + + auto *columns = yyjson_obj_get(table, "columns"); + duckdb_yyjson::yyjson_val *col; + size_t col_idx, col_max; + yyjson_arr_foreach(columns, col_idx, col_max, col) { + auto column_definition = ParseColumnDefinition(col); + table_result.columns.push_back(column_definition); + } + + result.push_back(table_result); + } + + return result; } vector UCAPI::GetSchemas(const string &catalog, UCCredentials credentials) { - vector result; + vector result; - auto api_result = GetRequest(credentials.endpoint + "/api/2.1/unity-catalog/schemas?catalog_name=" + catalog, credentials.token); + auto api_result = + GetRequest(credentials.endpoint + "/api/2.1/unity-catalog/schemas?catalog_name=" + catalog, credentials.token); - // Read JSON and get root - duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); - duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); + // Read JSON and get root + duckdb_yyjson::yyjson_doc *doc = duckdb_yyjson::yyjson_read(api_result.c_str(), api_result.size(), 0); + duckdb_yyjson::yyjson_val *root = yyjson_doc_get_root(doc); - // Get root["hits"], iterate over the array - auto *schemas = yyjson_obj_get(root, "schemas"); - size_t idx, max; - duckdb_yyjson::yyjson_val *schema; - yyjson_arr_foreach(schemas, idx, max, schema) { - UCAPISchema schema_result; + // Get root["hits"], iterate over the array + auto *schemas = yyjson_obj_get(root, "schemas"); + size_t idx, max; + duckdb_yyjson::yyjson_val *schema; + yyjson_arr_foreach(schemas, idx, max, schema) { + UCAPISchema schema_result; - auto *name = yyjson_obj_get(schema, "name"); - if (name) { - schema_result.schema_name = yyjson_get_str(name); - } - schema_result.catalog_name = catalog; + auto *name = yyjson_obj_get(schema, "name"); + if (name) { + schema_result.schema_name = yyjson_get_str(name); + } + schema_result.catalog_name = catalog; - result.push_back(schema_result); - } + result.push_back(schema_result); + } - return result; + return result; } } // namespace duckdb diff --git a/src/uc_catalog_extension.cpp b/src/uc_catalog_extension.cpp index 7dbdebb..6fcfe8e 100644 --- a/src/uc_catalog_extension.cpp +++ b/src/uc_catalog_extension.cpp @@ -4,7 +4,6 @@ #include "storage/uc_catalog.hpp" #include "storage/uc_transaction_manager.hpp" - #include "duckdb.hpp" #include "duckdb/main/secret/secret_manager.hpp" #include "duckdb/common/exception.hpp" @@ -18,143 +17,145 @@ namespace duckdb { static unique_ptr CreateUCSecretFunction(ClientContext &, CreateSecretInput &input) { - // apply any overridden settings - vector prefix_paths; - auto result = make_uniq(prefix_paths, "uc", "config", input.name); - for (const auto &named_param : input.options) { - auto lower_name = StringUtil::Lower(named_param.first); - - if (lower_name == "token") { - result->secret_map["token"] = named_param.second.ToString(); - } else if (lower_name == "endpoint") { - result->secret_map["endpoint"] = named_param.second.ToString(); - } else if (lower_name == "aws_region") { - result->secret_map["aws_region"] = named_param.second.ToString(); - } else { - throw InternalException("Unknown named parameter passed to CreateUCSecretFunction: " + lower_name); - } - } - - //! Set redact keys - result->redact_keys = {"token"}; - - return std::move(result); + // apply any overridden settings + vector prefix_paths; + auto result = make_uniq(prefix_paths, "uc", "config", input.name); + for (const auto &named_param : input.options) { + auto lower_name = StringUtil::Lower(named_param.first); + + if (lower_name == "token") { + result->secret_map["token"] = named_param.second.ToString(); + } else if (lower_name == "endpoint") { + result->secret_map["endpoint"] = named_param.second.ToString(); + } else if (lower_name == "aws_region") { + result->secret_map["aws_region"] = named_param.second.ToString(); + } else { + throw InternalException("Unknown named parameter passed to CreateUCSecretFunction: " + lower_name); + } + } + + //! Set redact keys + result->redact_keys = {"token"}; + + return std::move(result); } static void SetUCSecretParameters(CreateSecretFunction &function) { - function.named_parameters["token"] = LogicalType::VARCHAR; - function.named_parameters["endpoint"] = LogicalType::VARCHAR; - function.named_parameters["aws_region"] = LogicalType::VARCHAR; + function.named_parameters["token"] = LogicalType::VARCHAR; + function.named_parameters["endpoint"] = LogicalType::VARCHAR; + function.named_parameters["aws_region"] = LogicalType::VARCHAR; } unique_ptr GetSecret(ClientContext &context, const string &secret_name) { - auto &secret_manager = SecretManager::Get(context); - auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); - // FIXME: this should be adjusted once the `GetSecretByName` API supports this - // use case - auto secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "memory"); - if (secret_entry) { - return secret_entry; - } - secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "local_file"); - if (secret_entry) { - return secret_entry; - } - return nullptr; + auto &secret_manager = SecretManager::Get(context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + // FIXME: this should be adjusted once the `GetSecretByName` API supports this + // use case + auto secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "memory"); + if (secret_entry) { + return secret_entry; + } + secret_entry = secret_manager.GetSecretByName(transaction, secret_name, "local_file"); + if (secret_entry) { + return secret_entry; + } + return nullptr; } -static unique_ptr UCCatalogAttach(StorageExtensionInfo *storage_info, - ClientContext &context, - AttachedDatabase &db, - const string &name, AttachInfo &info, +static unique_ptr UCCatalogAttach(StorageExtensionInfo *storage_info, ClientContext &context, + AttachedDatabase &db, const string &name, AttachInfo &info, AccessMode access_mode) { - UCCredentials credentials; - - // check if we have a secret provided - string secret_name; - for (auto &entry : info.options) { - auto lower_name = StringUtil::Lower(entry.first); - if (lower_name == "type" || lower_name == "read_only") { - // already handled - } else if (lower_name == "secret") { - secret_name = entry.second.ToString(); - } else { - throw BinderException("Unrecognized option for UC attach: %s", entry.first); - } - } - - // if no secret is specified we default to the unnamed mysql secret, if it - // exists - bool explicit_secret = !secret_name.empty(); - if (!explicit_secret) { - // look up settings from the default unnamed mysql secret if none is - // provided - secret_name = "__default_uc"; - } - - string connection_string = info.path; - auto secret_entry = GetSecret(context, secret_name); - if (secret_entry) { - // secret found - read data - const auto &kv_secret = dynamic_cast(*secret_entry->secret); - string new_connection_info; - - Value input_val = kv_secret.TryGetValue("token"); - credentials.token = input_val.IsNull() ? "" : input_val.ToString(); - - Value endpoint_val = kv_secret.TryGetValue("endpoint"); - credentials.endpoint = endpoint_val.IsNull() ? "" : endpoint_val.ToString(); - StringUtil::RTrim(credentials.endpoint, "/"); - - Value aws_region_val = kv_secret.TryGetValue("aws_region"); - credentials.aws_region = endpoint_val.IsNull() ? "" : aws_region_val.ToString(); - - } else if (explicit_secret) { - // secret not found and one was explicitly provided - throw an error - throw BinderException("Secret with name \"%s\" not found", secret_name); - } - - return make_uniq(db, info.path, access_mode, credentials); + UCCredentials credentials; + + // check if we have a secret provided + string secret_name; + for (auto &entry : info.options) { + auto lower_name = StringUtil::Lower(entry.first); + if (lower_name == "type" || lower_name == "read_only") { + // already handled + } else if (lower_name == "secret") { + secret_name = entry.second.ToString(); + } else { + throw BinderException("Unrecognized option for UC attach: %s", entry.first); + } + } + + // if no secret is specified we default to the unnamed mysql secret, if it + // exists + bool explicit_secret = !secret_name.empty(); + if (!explicit_secret) { + // look up settings from the default unnamed mysql secret if none is + // provided + secret_name = "__default_uc"; + } + + string connection_string = info.path; + auto secret_entry = GetSecret(context, secret_name); + if (secret_entry) { + // secret found - read data + const auto &kv_secret = dynamic_cast(*secret_entry->secret); + string new_connection_info; + + Value input_val = kv_secret.TryGetValue("token"); + credentials.token = input_val.IsNull() ? "" : input_val.ToString(); + + Value endpoint_val = kv_secret.TryGetValue("endpoint"); + credentials.endpoint = endpoint_val.IsNull() ? "" : endpoint_val.ToString(); + StringUtil::RTrim(credentials.endpoint, "/"); + + Value aws_region_val = kv_secret.TryGetValue("aws_region"); + credentials.aws_region = endpoint_val.IsNull() ? "" : aws_region_val.ToString(); + + } else if (explicit_secret) { + // secret not found and one was explicitly provided - throw an error + throw BinderException("Secret with name \"%s\" not found", secret_name); + } + + return make_uniq(db, info.path, access_mode, credentials); } -static unique_ptr CreateTransactionManager(StorageExtensionInfo *storage_info, - AttachedDatabase &db, Catalog &catalog) { - auto &uc_catalog = catalog.Cast(); - return make_uniq(db, uc_catalog); +static unique_ptr CreateTransactionManager(StorageExtensionInfo *storage_info, AttachedDatabase &db, + Catalog &catalog) { + auto &uc_catalog = catalog.Cast(); + return make_uniq(db, uc_catalog); } class UCCatalogStorageExtension : public StorageExtension { public: - UCCatalogStorageExtension() { - attach = UCCatalogAttach; - create_transaction_manager = CreateTransactionManager; - } + UCCatalogStorageExtension() { + attach = UCCatalogAttach; + create_transaction_manager = CreateTransactionManager; + } }; static void LoadInternal(DatabaseInstance &instance) { - SecretType secret_type; - secret_type.name = "uc"; - secret_type.deserializer = KeyValueSecret::Deserialize; - secret_type.default_provider = "config"; + SecretType secret_type; + secret_type.name = "uc"; + secret_type.deserializer = KeyValueSecret::Deserialize; + secret_type.default_provider = "config"; - ExtensionUtil::RegisterSecretType(instance, secret_type); + ExtensionUtil::RegisterSecretType(instance, secret_type); - CreateSecretFunction mysql_secret_function = {"uc", "config", CreateUCSecretFunction}; - SetUCSecretParameters(mysql_secret_function); - ExtensionUtil::RegisterFunction(instance, mysql_secret_function); + CreateSecretFunction mysql_secret_function = {"uc", "config", CreateUCSecretFunction}; + SetUCSecretParameters(mysql_secret_function); + ExtensionUtil::RegisterFunction(instance, mysql_secret_function); - auto &config = DBConfig::GetConfig(instance); - config.storage_extensions["uc_catalog"] = make_uniq(); + auto &config = DBConfig::GetConfig(instance); + config.storage_extensions["uc_catalog"] = make_uniq(); } -void UcCatalogExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } -std::string UcCatalogExtension::Name() { return "uc_catalog"; } +void UcCatalogExtension::Load(DuckDB &db) { + LoadInternal(*db.instance); +} +std::string UcCatalogExtension::Name() { + return "uc_catalog"; +} std::string UcCatalogExtension::Version() const { #ifdef EXT_VERSION_UC_CATALOG - return EXT_VERSION_UC_CATALOG; + return EXT_VERSION_UC_CATALOG; #else - return ""; + return ""; #endif } @@ -163,12 +164,12 @@ std::string UcCatalogExtension::Version() const { extern "C" { DUCKDB_EXTENSION_API void uc_catalog_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); + duckdb::DuckDB db_wrapper(db); + db_wrapper.LoadExtension(); } DUCKDB_EXTENSION_API const char *uc_catalog_version() { - return duckdb::DuckDB::LibraryVersion(); + return duckdb::DuckDB::LibraryVersion(); } } diff --git a/src/uc_utils.cpp b/src/uc_utils.cpp index 8fc73dc..f746635 100644 --- a/src/uc_utils.cpp +++ b/src/uc_utils.cpp @@ -30,127 +30,127 @@ string UCUtils::TypeToString(const LogicalType &input) { LogicalType UCUtils::TypeToLogicalType(ClientContext &context, const string &type_text) { if (type_text == "tinyint") { - return LogicalType::TINYINT; + return LogicalType::TINYINT; } else if (type_text == "smallint") { - return LogicalType::SMALLINT; - } else if (type_text == "bigint") { - return LogicalType::BIGINT; - } else if (type_text == "int") { - return LogicalType::INTEGER; + return LogicalType::SMALLINT; + } else if (type_text == "bigint") { + return LogicalType::BIGINT; + } else if (type_text == "int") { + return LogicalType::INTEGER; } else if (type_text == "long") { - return LogicalType::BIGINT; + return LogicalType::BIGINT; } else if (type_text == "string") { - return LogicalType::VARCHAR; + return LogicalType::VARCHAR; } else if (type_text == "double") { - return LogicalType::DOUBLE; + return LogicalType::DOUBLE; } else if (type_text == "float") { - return LogicalType::FLOAT; + return LogicalType::FLOAT; } else if (type_text == "boolean") { - return LogicalType::BOOLEAN; + return LogicalType::BOOLEAN; } else if (type_text == "timestamp") { - return LogicalType::TIMESTAMP; - } else if (type_text == "binary") { - return LogicalType::BLOB; - } else if (type_text == "date") { - return LogicalType::DATE; - } else if (type_text == "timestamp") { - return LogicalType::TIMESTAMP; // TODO: Is this the right timestamp - } else if (type_text.find("decimal(") == 0) { - size_t spec_end = type_text.find(')'); - if (spec_end != string::npos) { - size_t sep = type_text.find(','); - auto prec_str = type_text.substr(8, sep - 8); - auto scale_str = type_text.substr(sep + 1, spec_end - sep - 1); - uint8_t prec = Cast::Operation(prec_str); - uint8_t scale = Cast::Operation(scale_str); - return LogicalType::DECIMAL(prec, scale); - } - } else if (type_text.find("array<") == 0) { - size_t type_end = type_text.rfind('>'); // find last, to deal with nested - if (type_end != string::npos) { - auto child_type_str = type_text.substr(6, type_end - 6); - auto child_type = UCUtils::TypeToLogicalType(context, child_type_str); - return LogicalType::LIST(child_type); - } - } else if (type_text.find("map<") == 0) { - size_t type_end = type_text.rfind('>'); // find last, to deal with nested - if (type_end != string::npos) { - // TODO: Factor this and struct parsing into an iterator over ',' separated values - vector key_val; - size_t cur = 4; - auto nested_opens = 0; - for (;;) { - size_t next_sep = cur; - // find the location of the next ',' ignoring nested commas - while (type_text[next_sep] != ',' || nested_opens > 0) { - if (type_text[next_sep] == '<') { - nested_opens++; - } else if (type_text[next_sep] == '>') { - nested_opens--; - } - next_sep++; - if (next_sep == type_end) { - break; - } - } - auto child_str = type_text.substr(cur, next_sep - cur); - auto child_type = UCUtils::TypeToLogicalType(context, child_str); - key_val.push_back(child_type); - if (next_sep == type_end) { - break; - } - cur = next_sep+1; - } - if (key_val.size() != 2) { - throw NotImplementedException("Invalid map specification with %i types", key_val.size()); - } - return LogicalType::MAP(key_val[0], key_val[1]); - } - } else if (type_text.find("struct<") == 0) { - size_t type_end = type_text.rfind('>'); // find last, to deal with nested - if (type_end != string::npos) { - child_list_t children; - size_t cur = 7; - auto nested_opens = 0; - for (;;) { - size_t next_sep = cur; - // find the location of the next ',' ignoring nested commas - while (type_text[next_sep] != ',' || nested_opens > 0) { - if (type_text[next_sep] == '<') { - nested_opens++; - } else if (type_text[next_sep] == '>') { - nested_opens--; - } - next_sep++; - if (next_sep == type_end) { - break; - } - } - auto child_str = type_text.substr(cur, next_sep - cur); - size_t type_sep = child_str.find(':'); - if (type_sep == string::npos) { - throw NotImplementedException("Invalid struct child type specifier: %s", child_str); - } - auto child_name = child_str.substr(0, type_sep); - auto child_type = UCUtils::TypeToLogicalType(context, child_str.substr(type_sep+1, string::npos)); - children.push_back({child_name, child_type}); - if (next_sep == type_end) { - break; - } - cur = next_sep+1; - } - return LogicalType::STRUCT(children); - } - } + return LogicalType::TIMESTAMP; + } else if (type_text == "binary") { + return LogicalType::BLOB; + } else if (type_text == "date") { + return LogicalType::DATE; + } else if (type_text == "timestamp") { + return LogicalType::TIMESTAMP; // TODO: Is this the right timestamp + } else if (type_text.find("decimal(") == 0) { + size_t spec_end = type_text.find(')'); + if (spec_end != string::npos) { + size_t sep = type_text.find(','); + auto prec_str = type_text.substr(8, sep - 8); + auto scale_str = type_text.substr(sep + 1, spec_end - sep - 1); + uint8_t prec = Cast::Operation(prec_str); + uint8_t scale = Cast::Operation(scale_str); + return LogicalType::DECIMAL(prec, scale); + } + } else if (type_text.find("array<") == 0) { + size_t type_end = type_text.rfind('>'); // find last, to deal with nested + if (type_end != string::npos) { + auto child_type_str = type_text.substr(6, type_end - 6); + auto child_type = UCUtils::TypeToLogicalType(context, child_type_str); + return LogicalType::LIST(child_type); + } + } else if (type_text.find("map<") == 0) { + size_t type_end = type_text.rfind('>'); // find last, to deal with nested + if (type_end != string::npos) { + // TODO: Factor this and struct parsing into an iterator over ',' separated values + vector key_val; + size_t cur = 4; + auto nested_opens = 0; + for (;;) { + size_t next_sep = cur; + // find the location of the next ',' ignoring nested commas + while (type_text[next_sep] != ',' || nested_opens > 0) { + if (type_text[next_sep] == '<') { + nested_opens++; + } else if (type_text[next_sep] == '>') { + nested_opens--; + } + next_sep++; + if (next_sep == type_end) { + break; + } + } + auto child_str = type_text.substr(cur, next_sep - cur); + auto child_type = UCUtils::TypeToLogicalType(context, child_str); + key_val.push_back(child_type); + if (next_sep == type_end) { + break; + } + cur = next_sep + 1; + } + if (key_val.size() != 2) { + throw NotImplementedException("Invalid map specification with %i types", key_val.size()); + } + return LogicalType::MAP(key_val[0], key_val[1]); + } + } else if (type_text.find("struct<") == 0) { + size_t type_end = type_text.rfind('>'); // find last, to deal with nested + if (type_end != string::npos) { + child_list_t children; + size_t cur = 7; + auto nested_opens = 0; + for (;;) { + size_t next_sep = cur; + // find the location of the next ',' ignoring nested commas + while (type_text[next_sep] != ',' || nested_opens > 0) { + if (type_text[next_sep] == '<') { + nested_opens++; + } else if (type_text[next_sep] == '>') { + nested_opens--; + } + next_sep++; + if (next_sep == type_end) { + break; + } + } + auto child_str = type_text.substr(cur, next_sep - cur); + size_t type_sep = child_str.find(':'); + if (type_sep == string::npos) { + throw NotImplementedException("Invalid struct child type specifier: %s", child_str); + } + auto child_name = child_str.substr(0, type_sep); + auto child_type = UCUtils::TypeToLogicalType(context, child_str.substr(type_sep + 1, string::npos)); + children.push_back({child_name, child_type}); + if (next_sep == type_end) { + break; + } + cur = next_sep + 1; + } + return LogicalType::STRUCT(children); + } + } - throw NotImplementedException("Tried to fallback to unknown type for '%s'", type_text); + throw NotImplementedException("Tried to fallback to unknown type for '%s'", type_text); // fallback for unknown types return LogicalType::VARCHAR; } LogicalType UCUtils::ToUCType(const LogicalType &input) { - //todo do we need this mapping? - throw NotImplementedException("ToUCType not yet implemented"); + // todo do we need this mapping? + throw NotImplementedException("ToUCType not yet implemented"); switch (input.id()) { case LogicalTypeId::BOOLEAN: case LogicalTypeId::SMALLINT: