From 380bc738c6049921e35eae771c35a554059aadc1 Mon Sep 17 00:00:00 2001
From: Anas Dorbani
Date: Sun, 12 Jan 2025 13:13:58 +0100
Subject: [PATCH 1/3] add the tuple_format to the model and prompt manager headers

---
 .../flockmtl/model_manager/repository.hpp      |  1 +
 .../flockmtl/prompt_manager/prompt_manager.hpp | 16 +++++++++++-----
 .../flockmtl/prompt_manager/repository.hpp     |  6 ++++++
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/include/flockmtl/model_manager/repository.hpp b/src/include/flockmtl/model_manager/repository.hpp
index d542e960..3d2a1d47 100644
--- a/src/include/flockmtl/model_manager/repository.hpp
+++ b/src/include/flockmtl/model_manager/repository.hpp
@@ -13,6 +13,7 @@ struct ModelDetails {
     int32_t max_output_tokens;
     float temperature;
     std::unordered_map<std::string, std::string> secret;
+    std::string tuple_format;
     int batch_size;
 };
diff --git a/src/include/flockmtl/prompt_manager/prompt_manager.hpp b/src/include/flockmtl/prompt_manager/prompt_manager.hpp
index d9b85046..0fd45107 100644
--- a/src/include/flockmtl/prompt_manager/prompt_manager.hpp
+++ b/src/include/flockmtl/prompt_manager/prompt_manager.hpp
@@ -35,16 +35,22 @@ class PromptManager {
 
     static std::string ConstructNumTuples(int num_tuples);
 
-    static std::string ConstructInputTuplesHeader(const nlohmann::json& tuple);
+    static std::string ConstructInputTuplesHeader(const nlohmann::json& tuple, const std::string& tuple_format = "XML");
+    static std::string ConstructInputTuplesHeaderXML(const nlohmann::json& tuple);
+    static std::string ConstructInputTuplesHeaderMarkdown(const nlohmann::json& tuple);
 
-    static std::string ConstructSingleInputTuple(const nlohmann::json& tuple);
+    static std::string ConstructSingleInputTuple(const nlohmann::json& tuple, const std::string& tuple_format = "XML");
+    static std::string ConstructSingleInputTupleXML(const nlohmann::json& tuple);
+    static std::string ConstructSingleInputTupleMarkdown(const nlohmann::json& tuple);
+    static std::string ConstructSingleInputTupleJSON(const nlohmann::json& tuple);
 
-    static std::string ConstructInputTuples(const nlohmann::json& tuples);
+    static std::string ConstructInputTuples(const nlohmann::json& tuples, const std::string& tuple_format = "XML");
 
     template <typename FunctionType>
-    static std::string Render(const std::string& user_prompt, const nlohmann::json& tuples, FunctionType option) {
+    static std::string Render(const std::string& user_prompt, const nlohmann::json& tuples, FunctionType option,
+                              const std::string& tuple_format = "XML") {
         auto prompt = PromptManager::GetTemplate(option);
-        auto markdown_tuples = PromptManager::ConstructInputTuples(tuples);
+        auto markdown_tuples = PromptManager::ConstructInputTuples(tuples, tuple_format);
         prompt = PromptManager::ReplaceSection(prompt, PromptSection::USER_PROMPT, user_prompt);
         prompt = PromptManager::ReplaceSection(prompt, PromptSection::TUPLES, markdown_tuples);
         return prompt;
diff --git a/src/include/flockmtl/prompt_manager/repository.hpp b/src/include/flockmtl/prompt_manager/repository.hpp
index 4d2d7db3..e84f45b7 100644
--- a/src/include/flockmtl/prompt_manager/repository.hpp
+++ b/src/include/flockmtl/prompt_manager/repository.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <string>
+#include <unordered_map>
 
 namespace flockmtl {
@@ -10,6 +11,11 @@ enum class AggregateFunctionType { REDUCE, REDUCE_JSON, FIRST, LAST, RERANK };
 
 enum class ScalarFunctionType { COMPLETE_JSON, COMPLETE, FILTER };
 
+enum class TupleFormat { XML, JSON, Markdown };
+
+inline std::pmr::unordered_map<std::string, TupleFormat> TUPLE_FORMAT = {
{"Markdown", TupleFormat::Markdown}}; + constexpr auto META_PROMPT = "You are FlockMTL a semantic analysis tool for DBMS. You will analyze each tuple in the provided data and respond " "to " From e6e8beea00a0237c4847ae81d5db4452b867da04 Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Sun, 12 Jan 2025 13:15:57 +0100 Subject: [PATCH 2/3] implement the tuple format to be supported within the scalar functions --- src/functions/scalar/scalar.cpp | 2 +- src/model_manager/model.cpp | 14 +++--- src/prompt_manager/prompt_manager.cpp | 64 ++++++++++++++++++++++++--- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/functions/scalar/scalar.cpp b/src/functions/scalar/scalar.cpp index fa868d17..2dc18189 100644 --- a/src/functions/scalar/scalar.cpp +++ b/src/functions/scalar/scalar.cpp @@ -5,7 +5,7 @@ namespace flockmtl { nlohmann::json ScalarFunctionBase::Complete(const nlohmann::json& tuples, const std::string& user_prompt, ScalarFunctionType function_type, Model& model) { nlohmann::json data; - const auto prompt = PromptManager::Render(user_prompt, tuples, function_type); + const auto prompt = PromptManager::Render(user_prompt, tuples, function_type, model.GetModelDetails().tuple_format); auto response = model.CallComplete(prompt); return response["tuples"]; }; diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index eda13f2c..f02b8904 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -25,14 +25,18 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) { secret_name = model_json["secret_name"].get(); } model_details_.secret = SecretManager::GetSecret(secret_name); - model_details_.context_window = - model_json.contains("context_window") ? std::stoi(model_json.at("context_window").get()) : std::get<2>(query_result); + model_details_.context_window = model_json.contains("context_window") + ? std::stoi(model_json.at("context_window").get()) + : std::get<2>(query_result); model_details_.max_output_tokens = model_json.contains("max_output_tokens") ? std::stoi(model_json.at("max_output_tokens").get()) : std::get<3>(query_result); - model_details_.temperature = model_json.contains("temperature") ? model_json.at("temperature").get() : 0.7; - model_details_.temperature = model_json.contains("temperature") ? std::stof(model_json.at("temperature").get()) : 0; - model_details_.batch_size = model_json.contains("batch_size") ? std::stoi(model_json.at("batch_size").get()) : 0; + model_details_.temperature = + model_json.contains("temperature") ? std::stof(model_json.at("temperature").get()) : 0; + model_details_.tuple_format = + model_json.contains("tuple_format") ? model_json.at("tuple_format").get() : "XML"; + model_details_.batch_size = + model_json.contains("batch_size") ? std::stoi(model_json.at("batch_size").get()) : 0; } std::tuple Model::GetQueriedModel(const std::string& model_name) { diff --git a/src/prompt_manager/prompt_manager.cpp b/src/prompt_manager/prompt_manager.cpp index 23f1e35a..16a625bd 100644 --- a/src/prompt_manager/prompt_manager.cpp +++ b/src/prompt_manager/prompt_manager.cpp @@ -38,16 +38,55 @@ std::string PromptManager::ReplaceSection(const std::string& prompt_template, co return prompt; } -std::string PromptManager::ConstructInputTuplesHeader(const nlohmann::json& tuple) { - auto header = std::string("
"); +std::string PromptManager::ConstructInputTuplesHeader(const nlohmann::json& tuple, const std::string& tuple_format) { + switch (const auto format = TUPLE_FORMAT.at(tuple_format)) { + case TupleFormat::XML: + return ConstructInputTuplesHeaderXML(tuple); + case TupleFormat::Markdown: + return ConstructInputTuplesHeaderMarkdown(tuple); + case TupleFormat::JSON: + return ""; + default: + throw std::runtime_error("Invalid tuple format provided `" + tuple_format + "`"); + } +} + +std::string PromptManager::ConstructInputTuplesHeaderXML(const nlohmann::json& tuple) { + auto header = std::string(""); for (const auto& key : tuple.items()) { header += "" + key.key() + ""; } - header += "
\n"; + header += "\n"; + return header; +} + +std::string PromptManager::ConstructInputTuplesHeaderMarkdown(const nlohmann::json& tuple) { + auto header = std::string("|"); + for (const auto& key : tuple.items()) { + header += key.key() + "|"; + } + header += "\n|"; + for (const auto& key : tuple.items()) { + header += "---|"; + } + header += "\n"; return header; } -std::string PromptManager::ConstructSingleInputTuple(const nlohmann::json& tuple) { +std::string PromptManager::ConstructSingleInputTuple(const nlohmann::json& tuple, const std::string& tuple_format) { + switch (const auto format = TUPLE_FORMAT.at(tuple_format)) { + case TupleFormat::XML: + return ConstructSingleInputTupleXML(tuple); + case TupleFormat::Markdown: + return ConstructSingleInputTupleMarkdown(tuple); + case TupleFormat::JSON: + return ConstructSingleInputTupleJSON(tuple); + default: + throw std::runtime_error("Invalid tuple format provided `" + tuple_format + "`"); + } +} + +std::string PromptManager::ConstructSingleInputTupleXML(const nlohmann::json& tuple) { auto tuple_str = std::string(""); for (const auto& key : tuple.items()) { tuple_str += "" + key.value().dump() + ""; @@ -56,16 +95,27 @@ std::string PromptManager::ConstructSingleInputTuple(const nlohmann::json& tuple return tuple_str; } +std::string PromptManager::ConstructSingleInputTupleMarkdown(const nlohmann::json& tuple) { + auto tuple_str = std::string("|"); + for (const auto& key : tuple.items()) { + tuple_str += key.value().dump() + "|"; + } + tuple_str += "\n"; + return tuple_str; +} + +std::string PromptManager::ConstructSingleInputTupleJSON(const nlohmann::json& tuple) { return tuple.dump() + "\n"; } + std::string PromptManager::ConstructNumTuples(const int num_tuples) { return "- The Number of Tuples to Generate Responses for: " + std::to_string(num_tuples) + "\n\n"; } -std::string PromptManager::ConstructInputTuples(const nlohmann::json& tuples) { +std::string PromptManager::ConstructInputTuples(const nlohmann::json& tuples, const std::string& tuple_format) { auto tuples_str = std::string(""); tuples_str += PromptManager::ConstructNumTuples(static_cast(tuples.size())); - tuples_str += PromptManager::ConstructInputTuplesHeader(tuples[0]); + tuples_str += PromptManager::ConstructInputTuplesHeader(tuples[0], tuple_format); for (const auto& tuple : tuples) { - tuples_str += PromptManager::ConstructSingleInputTuple(tuple); + tuples_str += PromptManager::ConstructSingleInputTuple(tuple, tuple_format); } return tuples_str; } From b66464dfe5e283aee68f87665905cd419da6b3e8 Mon Sep 17 00:00:00 2001 From: Anas Dorbani Date: Sun, 12 Jan 2025 13:19:12 +0100 Subject: [PATCH 3/3] implement the tuple format to be supported within the aggregate functions --- src/functions/aggregate/llm_first_or_last/implementation.cpp | 2 +- src/functions/aggregate/llm_reduce/implementation.cpp | 2 +- src/functions/aggregate/llm_rerank/implementation.cpp | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/functions/aggregate/llm_first_or_last/implementation.cpp b/src/functions/aggregate/llm_first_or_last/implementation.cpp index 2a29f80e..afc6d89d 100644 --- a/src/functions/aggregate/llm_first_or_last/implementation.cpp +++ b/src/functions/aggregate/llm_first_or_last/implementation.cpp @@ -18,7 +18,7 @@ int LlmFirstOrLast::GetAvailableTokens() { int LlmFirstOrLast::GetFirstOrLastTupleId(const nlohmann::json& tuples) { nlohmann::json data; - auto prompt = PromptManager::Render(user_query, tuples, function_type); + const auto prompt = 
     auto response = model.CallComplete(prompt);
     return response["selected"].get<int>();
 }
diff --git a/src/functions/aggregate/llm_reduce/implementation.cpp b/src/functions/aggregate/llm_reduce/implementation.cpp
index 550259b1..e6921eb6 100644
--- a/src/functions/aggregate/llm_reduce/implementation.cpp
+++ b/src/functions/aggregate/llm_reduce/implementation.cpp
@@ -18,7 +18,7 @@ int LlmReduce::GetAvailableTokens(const AggregateFunctionType& function_type) {
 
 nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type) {
     nlohmann::json data;
-    auto prompt = PromptManager::Render(user_query, tuples, function_type);
+    const auto prompt = PromptManager::Render(user_query, tuples, function_type, model.GetModelDetails().tuple_format);
     auto response = model.CallComplete(prompt);
     return response["output"];
 };
diff --git a/src/functions/aggregate/llm_rerank/implementation.cpp b/src/functions/aggregate/llm_rerank/implementation.cpp
index c2050bbf..38bd279c 100644
--- a/src/functions/aggregate/llm_rerank/implementation.cpp
+++ b/src/functions/aggregate/llm_rerank/implementation.cpp
@@ -19,7 +19,8 @@ int LlmRerank::GetAvailableTokens() {
 
 std::vector<int> LlmRerank::RerankBatch(const nlohmann::json& tuples) {
     nlohmann::json data;
-    auto prompt = PromptManager::Render(user_query, tuples, AggregateFunctionType::RERANK);
+    auto prompt =
+        PromptManager::Render(user_query, tuples, AggregateFunctionType::RERANK, model.GetModelDetails().tuple_format);
     auto response = model.CallComplete(prompt);
     return response["ranking"].get<std::vector<int>>();
 };
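---

Reviewer note: below is a minimal standalone sketch of what the three tuple formats render, using the
construction helpers added in this series. The driver is illustrative only and not part of the patches;
it assumes the PromptManager helpers are publicly accessible and that the include path matches the
repo layout.

#include <iostream>
#include <nlohmann/json.hpp>

#include "flockmtl/prompt_manager/prompt_manager.hpp"

int main() {
    // Two sample rows; nlohmann::json stores object keys in sorted order,
    // so columns render alphabetically ("age" before "name").
    const auto tuples = nlohmann::json::parse(R"([
        {"name": "Alice", "age": 30},
        {"name": "Bob", "age": 25}
    ])");
    // "XML" wraps rows in <tuple><col>...</col></tuple> under a <header> row,
    // "Markdown" emits a pipe-delimited table, and "JSON" dumps one object
    // per line with no header; any other string throws (TUPLE_FORMAT.at).
    for (const std::string format : {"XML", "Markdown", "JSON"}) {
        std::cout << "== " << format << " ==\n"
                  << flockmtl::PromptManager::ConstructInputTuples(tuples, format) << '\n';
    }
    return 0;
}

For the first row this yields <tuple><col>30</col><col>"Alice"</col></tuple> (XML), |30|"Alice"|
(Markdown), and {"age":30,"name":"Alice"} (JSON), each preceded by the tuple-count line from
ConstructNumTuples; values go through dump(), so strings keep their quotes in the XML and Markdown
outputs as well.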