From a9c9d67794f96c2ad76066af0ea741a9a1351eb2 Mon Sep 17 00:00:00 2001 From: Anas DORBANI <95044293+dorbanianas@users.noreply.github.com> Date: Sun, 8 Dec 2024 02:17:14 +0100 Subject: [PATCH] Udpate the azure type into azure_llm in secret manager (#88) --- .../secret_manager/secret_manager.hpp | 1 + src/model_manager/model.cpp | 2 ++ src/secret_manager/secret_manager.cpp | 19 +++++++------------ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/include/flockmtl/secret_manager/secret_manager.hpp b/src/include/flockmtl/secret_manager/secret_manager.hpp index 3ac43e99..3d904f2d 100644 --- a/src/include/flockmtl/secret_manager/secret_manager.hpp +++ b/src/include/flockmtl/secret_manager/secret_manager.hpp @@ -22,6 +22,7 @@ extern const SecretDetails ollama_secret_details; class SecretManager { public: enum SupportedProviders { OPENAI, AZURE, OLLAMA }; + static std::unordered_map providerNames; static void Register(duckdb::DatabaseInstance& instance); static std::unordered_map GetSecret(std::string provider); diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index b990c108..ca4343e4 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -22,6 +22,8 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) { model_details_.provider_name = model_json.contains("provider") ? model_json.at("provider").get() : std::get<1>(query_result); auto secret_name = "__default_" + model_details_.provider_name; + if (model_details_.provider_name == AZURE) + secret_name += "_llm"; if (model_json.contains("secret_name")) { secret_name = model_json["secret_name"].get(); } diff --git a/src/secret_manager/secret_manager.cpp b/src/secret_manager/secret_manager.cpp index f2e63812..1665e577 100644 --- a/src/secret_manager/secret_manager.cpp +++ b/src/secret_manager/secret_manager.cpp @@ -8,25 +8,20 @@ namespace flockmtl { const SecretDetails openai_secret_details = {"openai", "flockmtl", "openai://", {"api_key"}, {"api_key"}, {"api_key"}}; -const SecretDetails azure_secret_details = {"azure", "flockmtl", - "azure://", {"api_key", "resource_name", "api_version"}, - {"api_key"}, {"api_key", "resource_name", "api_version"}}; +const SecretDetails azure_secret_details = {"azure_llm", "flockmtl", + "azure_llm://", {"api_key", "resource_name", "api_version"}, + {"api_key"}, {"api_key", "resource_name", "api_version"}}; const SecretDetails ollama_secret_details = {"ollama", "flockmtl", "ollama://", {"api_url"}, {"api_url"}, {"api_url"}}; const std::vector secret_details_list = {&openai_secret_details, &azure_secret_details, &ollama_secret_details}; +std::unordered_map SecretManager::providerNames = { + {"openai", OPENAI}, {"azure", AZURE}, {"ollama", OLLAMA}}; + SecretManager::SupportedProviders SecretManager::GetProviderType(std::string provider) { - if (provider == "openai") { - return OPENAI; - } else if (provider == "azure") { - return AZURE; - } else if (provider == "ollama") { - return OLLAMA; - } else { - throw duckdb::InvalidInputException("Unsupported secret type: %s", provider.c_str()); - } + return providerNames[provider]; } void SecretManager::Register(duckdb::DatabaseInstance& instance) {