diff --git a/src/functions/aggregate/llm_first_or_last/implementation.cpp b/src/functions/aggregate/llm_first_or_last/implementation.cpp
index 61d6dd7b..2a29f80e 100644
--- a/src/functions/aggregate/llm_first_or_last/implementation.cpp
+++ b/src/functions/aggregate/llm_first_or_last/implementation.cpp
@@ -32,11 +32,13 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json& tuples) {
     do {
         accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
         accumulated_tuples_tokens +=
-            Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
+            Tiktoken::GetNumTokens(PromptManager::ConstructNumTuples(static_cast<int>(tuples.size())));
+        accumulated_tuples_tokens +=
+            Tiktoken::GetNumTokens(PromptManager::ConstructInputTuplesHeader(tuples[start_index]));
         while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
                start_index < static_cast<int>(tuples.size())) {
             const auto num_tokens =
-                Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
+                Tiktoken::GetNumTokens(PromptManager::ConstructSingleInputTuple(tuples[start_index]));
             if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
                 break;
             }
diff --git a/src/functions/aggregate/llm_reduce/implementation.cpp b/src/functions/aggregate/llm_reduce/implementation.cpp
index 82792318..550259b1 100644
--- a/src/functions/aggregate/llm_reduce/implementation.cpp
+++ b/src/functions/aggregate/llm_reduce/implementation.cpp
@@ -33,11 +33,13 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
     do {
         accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
         accumulated_tuples_tokens +=
-            Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
+            Tiktoken::GetNumTokens(PromptManager::ConstructNumTuples(static_cast<int>(tuples.size())));
+        accumulated_tuples_tokens +=
+            Tiktoken::GetNumTokens(PromptManager::ConstructInputTuplesHeader(tuples[start_index]));
         while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
                start_index < static_cast<int>(tuples.size())) {
             const auto num_tokens =
-                Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
+                Tiktoken::GetNumTokens(PromptManager::ConstructSingleInputTuple(tuples[start_index]));
             if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
                 break;
             }
diff --git a/src/functions/aggregate/llm_rerank/implementation.cpp b/src/functions/aggregate/llm_rerank/implementation.cpp
index bc152ad6..c2050bbf 100644
--- a/src/functions/aggregate/llm_rerank/implementation.cpp
+++ b/src/functions/aggregate/llm_rerank/implementation.cpp
@@ -40,9 +40,12 @@ nlohmann::json LlmRerank::SlidingWindow(nlohmann::json& tuples) {
         next_tuples.clear();
         batch_size = half_batch;
         accumulated_rows_tokens = Tiktoken::GetNumTokens(window_tuples.dump());
-        accumulated_rows_tokens += Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownHeader(tuples[start_index]));
+        accumulated_rows_tokens +=
+            Tiktoken::GetNumTokens(PromptManager::ConstructNumTuples(static_cast<int>(tuples.size())));
+        accumulated_rows_tokens +=
+            Tiktoken::GetNumTokens(PromptManager::ConstructInputTuplesHeader(tuples[start_index]));
         while (available_tokens - accumulated_rows_tokens > 0 && start_index >= 0) {
-            auto num_tokens = Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
+            auto num_tokens = Tiktoken::GetNumTokens(PromptManager::ConstructSingleInputTuple(tuples[start_index]));
             if (accumulated_rows_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
                 break;
             }
diff --git a/src/functions/scalar/llm_embedding/implementation.cpp b/src/functions/scalar/llm_embedding/implementation.cpp
index c12a6c09..99385093 100644
--- a/src/functions/scalar/llm_embedding/implementation.cpp
+++ b/src/functions/scalar/llm_embedding/implementation.cpp
@@ -30,14 +30,26 @@ std::vector<duckdb::vector<duckdb::Value>> LlmEmbedding::Operation(duckdb::DataC
         prepared_inputs.push_back(concat_input);
     }
 
-    auto embeddings = model.CallEmbedding(prepared_inputs);
+    auto batch_size = model.GetModelDetails().batch_size;
+
+    if (batch_size == 0 || batch_size > prepared_inputs.size()) {
+        batch_size = static_cast<int>(prepared_inputs.size());
+    }
+
     std::vector<duckdb::vector<duckdb::Value>> results;
-    for (size_t index = 0; index < embeddings.size(); index++) {
-        duckdb::vector<duckdb::Value> embedding;
-        for (auto& value : embeddings[index]) {
-            embedding.push_back(duckdb::Value(static_cast<double>(value)));
+    for (size_t i = 0; i < prepared_inputs.size(); i += batch_size) {
+        std::vector<std::string> batch_inputs;
+        for (size_t j = i; j < i + batch_size && j < prepared_inputs.size(); j++) {
+            batch_inputs.push_back(prepared_inputs[j]);
+        }
+        auto embeddings = model.CallEmbedding(batch_inputs);
+        for (size_t index = 0; index < embeddings.size(); index++) {
+            duckdb::vector<duckdb::Value> embedding;
+            for (auto& value : embeddings[index]) {
+                embedding.push_back(duckdb::Value(static_cast<double>(value)));
+            }
+            results.push_back(embedding);
         }
-        results.push_back(embedding);
     }
     return results;
 }
diff --git a/src/functions/scalar/llm_filter/implementation.cpp b/src/functions/scalar/llm_filter/implementation.cpp
index 8d3da808..d1b19b7b 100644
--- a/src/functions/scalar/llm_filter/implementation.cpp
+++ b/src/functions/scalar/llm_filter/implementation.cpp
@@ -36,6 +36,10 @@ std::vector<std::string> LlmFilter::Operation(duckdb::DataChunk& args) {
     std::vector<std::string> results;
     results.reserve(responses.size());
     for (const auto& response : responses) {
+        if (response.is_null()) {
+            results.emplace_back("True");
+            continue;
+        }
         results.push_back(response.dump());
     }
 
diff --git a/src/functions/scalar/scalar.cpp b/src/functions/scalar/scalar.cpp
index e47e42b9..fa868d17 100644
--- a/src/functions/scalar/scalar.cpp
+++ b/src/functions/scalar/scalar.cpp
@@ -18,56 +18,101 @@ nlohmann::json ScalarFunctionBase::BatchAndComplete(const std::vector<nlohmann::json>& tuples,
-            while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
-                   start_index < static_cast<int>(tuples.size()) && batch_tuples.size() < batch_size) {
-                const auto num_tokens =
-                    Tiktoken::GetNumTokens(PromptManager::ConstructMarkdownSingleTuple(tuples[start_index]));
-                if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
-                    break;
+        if (batch_size == 0) {
+            auto accumulated_tuples_tokens = 0u;
+            batch_size = static_cast<int>(tuples.size());
+            do {
+                accumulated_tuples_tokens +=
+                    Tiktoken::GetNumTokens(PromptManager::ConstructNumTuples(static_cast<int>(tuples.size())));
+                accumulated_tuples_tokens +=
+                    Tiktoken::GetNumTokens(PromptManager::ConstructInputTuplesHeader(tuples[start_index]));
+                while (accumulated_tuples_tokens < static_cast<unsigned int>(available_tokens) &&
+                       start_index < static_cast<int>(tuples.size()) && batch_tuples.size() < batch_size) {
+                    const auto num_tokens =
+                        Tiktoken::GetNumTokens(PromptManager::ConstructSingleInputTuple(tuples[start_index]));
+                    if (accumulated_tuples_tokens + num_tokens > static_cast<unsigned int>(available_tokens)) {
+                        break;
+                    }
+                    batch_tuples.push_back(tuples[start_index]);
+                    accumulated_tuples_tokens += num_tokens;
+                    start_index++;
+                }
+
+                nlohmann::json response;
+                try {
+                    response = Complete(batch_tuples, user_prompt, function_type, model);
+                } catch (const ExceededMaxOutputTokensError&) {
+                    batch_tuples.clear();
+                    const auto new_batch_size = static_cast<int>(batch_size / 10);
+                    batch_size = batch_size == 1 ? new_batch_size == 0 : new_batch_size;
+                    accumulated_tuples_tokens = 0;
+                    start_index = 0;
+                    continue;
                 }
-                batch_tuples.push_back(tuples[start_index]);
-                accumulated_tuples_tokens += num_tokens;
-                start_index++;
-            }
-
-            nlohmann::json response;
-            try {
-                response = Complete(batch_tuples, user_prompt, function_type, model);
-            } catch (const ExceededMaxOutputTokensError&) {
+
+                if (response.size() < batch_tuples.size()) {
+                    for (auto i = static_cast<int>(response.size()); i < batch_tuples.size(); i++) {
+                        response.push_back(nullptr);
+                    }
+                } else if (response.size() > batch_size) {
+                    auto new_response = nlohmann::json::array();
+                    for (auto i = 0; i < batch_size; i++) {
+                        new_response.push_back(response[i]);
+                    }
+                    response = new_response;
+                }
+
+                auto output_tokens_per_tuple = Tiktoken::GetNumTokens(response.dump()) / batch_tuples.size();
+
+                batch_size = model_details.max_output_tokens / static_cast<int>(output_tokens_per_tuple);
                 batch_tuples.clear();
-                const auto new_batch_size = static_cast<int>(batch_size * 0.1);
-                batch_size = batch_size == 1 ? new_batch_size == 0 : new_batch_size;
                 accumulated_tuples_tokens = 0;
-                start_index = 0;
-                continue;
-            }
-            auto output_tokens_per_tuple = Tiktoken::GetNumTokens(response.dump()) / batch_tuples.size();
-            batch_size = model.GetModelDetails().max_output_tokens / output_tokens_per_tuple;
-            batch_tuples.clear();
-            accumulated_tuples_tokens = 0;
+                for (const auto& tuple : response) {
+                    responses.push_back(tuple);
+                }
+
+            } while (start_index < static_cast<int>(tuples.size()));
+        } else {
+            do {
+                for (auto i = 0; i < batch_size; i++) {
+                    if (start_index + i < static_cast<int>(tuples.size())) {
+                        batch_tuples.push_back(tuples[start_index + i]);
+                    }
+                }
+                start_index += batch_size;
 
-            for (const auto& tuple : response) {
-                responses.push_back(tuple);
-            }
+                auto response = Complete(batch_tuples, user_prompt, function_type, model);
 
-        } while (start_index < static_cast<int>(tuples.size()));
+                if (response.size() < batch_tuples.size()) {
+                    for (auto i = static_cast<int>(response.size()); i < batch_tuples.size(); i++) {
+                        response.push_back(nullptr);
+                    }
+                } else if (response.size() > batch_size) {
+                    auto new_response = nlohmann::json::array();
+                    for (auto i = 0; i < batch_size; i++) {
+                        new_response.push_back(response[i]);
+                    }
+                    response = new_response;
+                }
+                batch_tuples.clear();
+                for (const auto& tuple : response) {
+                    responses.push_back(tuple);
+                }
+            } while (start_index < static_cast<int>(tuples.size()));
+        }
     }
 
     return responses;
diff --git a/src/include/flockmtl/model_manager/repository.hpp b/src/include/flockmtl/model_manager/repository.hpp
index e3569f59..d542e960 100644
--- a/src/include/flockmtl/model_manager/repository.hpp
+++ b/src/include/flockmtl/model_manager/repository.hpp
@@ -13,6 +13,7 @@ struct ModelDetails {
     int32_t max_output_tokens;
     float temperature;
     std::unordered_map<std::string, std::string> secret;
+    int batch_size;
 };
 
 const std::string OLLAMA = "ollama";
diff --git a/src/include/flockmtl/prompt_manager/prompt_manager.hpp b/src/include/flockmtl/prompt_manager/prompt_manager.hpp
index a135d264..d9b85046 100644
--- a/src/include/flockmtl/prompt_manager/prompt_manager.hpp
+++ b/src/include/flockmtl/prompt_manager/prompt_manager.hpp
@@ -33,16 +33,18 @@ class PromptManager {
 
     static PromptDetails CreatePromptDetails(const nlohmann::json& prompt_details_json);
 
-    static std::string ConstructMarkdownHeader(const nlohmann::json& tuple);
+    static std::string ConstructNumTuples(int num_tuples);
 
-    static std::string ConstructMarkdownSingleTuple(const nlohmann::json& tuple);
+    static std::string ConstructInputTuplesHeader(const nlohmann::json& tuple);
 
-    static std::string ConstructMarkdownArrayTuples(const nlohmann::json& tuples);
+    static std::string ConstructSingleInputTuple(const nlohmann::json& tuple);
+
+    static std::string ConstructInputTuples(const nlohmann::json& tuples);
 
     template <typename FunctionType>
     static std::string Render(const std::string& user_prompt, const nlohmann::json& tuples, FunctionType option) {
         auto prompt = PromptManager::GetTemplate(option);
-        auto markdown_tuples = PromptManager::ConstructMarkdownArrayTuples(tuples);
+        auto markdown_tuples = PromptManager::ConstructInputTuples(tuples);
         prompt = PromptManager::ReplaceSection(prompt, PromptSection::USER_PROMPT, user_prompt);
         prompt = PromptManager::ReplaceSection(prompt, PromptSection::TUPLES, markdown_tuples);
         return prompt;
diff --git a/src/include/flockmtl/prompt_manager/repository.hpp b/src/include/flockmtl/prompt_manager/repository.hpp
index 1af5459e..4d2d7db3 100644
--- a/src/include/flockmtl/prompt_manager/repository.hpp
+++ b/src/include/flockmtl/prompt_manager/repository.hpp
@@ -11,15 +11,16 @@ enum class AggregateFunctionType { REDUCE, REDUCE_JSON, FIRST, LAST, RERANK };
 enum class ScalarFunctionType { COMPLETE_JSON, COMPLETE, FILTER };
 
 constexpr auto META_PROMPT =
-    "You are a semantic analysis tool for DBMS. The tool will analyze each tuple in the provided data and respond to "
-    "user requests based on this context.\n\nUser Prompt:\n\n- {{USER_PROMPT}}\n\nTuples "
+    "You are FlockMTL a semantic analysis tool for DBMS. You will analyze each tuple in the provided data and respond "
+    "to "
+    "the user prompt.\n\nUser Prompt:\n\n- {{USER_PROMPT}}\n\nTuples "
     "Table:\n\n{{TUPLES}}\n\nInstructions:\n\n{{INSTRUCTIONS}}\n\nExpected Response Format:\n\n{{RESPONSE_FORMAT}}";
 
 class INSTRUCTIONS {
 public:
     static constexpr auto SCALAR_FUNCTION =
         "- The response should be directly relevant to each tuple without additional formatting, purely answering the "
-        "prompt as if each tuple were a standalone entity.\n- Use clear, context-relevant language to generate a "
+        "user prompt as if each tuple were a standalone entity.\n- Use clear, context-relevant language to generate a "
         "meaningful and concise answer for each tuple.";
     static constexpr auto AGGREGATE_FUNCTION =
         "- For each tuple in the provided data, evaluate the relevant attribute(s) based on the user prompt.\n- After "
@@ -36,18 +37,18 @@
 class RESPONSE_FORMAT {
 public:
     // Scalar Functions
     static constexpr auto COMPLETE_JSON =
-        "The system should interpret database tuples and provide a response to the user's prompt for each tuple in a "
-        "JSON format that contains the necessary columns for the answer.\n\nThe tool should respond in JSON format as "
-        "follows:\n\n```json\n{\t\"tuples\": [\n\t\t{},\n\t\t{},\n\t\t...\n\t\t{}\n\t]\n}\n```";
+        "You should return the responses to the user's prompt for each tuple in a "
+        "JSON format that contains the necessary columns for the answer.\n\nThe tool should respond in JSON format:\n\n"
+        "```json\n{\"tuples\": [{},{}, ..., {}]}\n```";
     static constexpr auto COMPLETE =
-        "The system should interpret database tuples and provide a response to the user's prompt for each tuple in "
-        "plain text.\n\tThe tool should respond in JSON format as follows:\n\n```json\n{\"tuples\": [\"<response>\", "
-        "\"<response>\", ... , \"<response>\"]}";
+        "You should return the responses to the user's prompt for each tuple in plain text. Ensure no tuple is " "missed.\n"
+        "Respond in the following JSON format:\n\n"
+        "```json\n{\"tuples\": [\"<response>\", \"<response>\", ..., \"<response>\"]}\n```";
     static constexpr auto FILTER =
-        "The system should interpret database tuples and provide a response to the user's prompt for each tuple in a "
+        "You should return the responses to the user's prompt for each tuple in a "
         "BOOL format that would be true/false.\n\tThe tool should respond in JSON format as "
-        "follows:\n\n```json\n{\"tuples\": [<bool_response>, <bool_response>, ... , <bool_response>]}";
+        "follows:\n\n```json\n{\"tuples\": [<bool_response>, <bool_response>, ... , <bool_response>]}\n```";
 
     // Aggregate Functions
     static constexpr auto REDUCE =
diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp
index 04daac5f..eda13f2c 100644
--- a/src/model_manager/model.cpp
+++ b/src/model_manager/model.cpp
@@ -26,11 +26,13 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) {
     }
     model_details_.secret = SecretManager::GetSecret(secret_name);
     model_details_.context_window =
-        model_json.contains("context_window") ? model_json.at("context_window").get<int32_t>() : std::get<2>(query_result);
+        model_json.contains("context_window") ? std::stoi(model_json.at("context_window").get<std::string>()) : std::get<2>(query_result);
     model_details_.max_output_tokens = model_json.contains("max_output_tokens")
-                                           ? model_json.at("max_output_tokens").get<int32_t>()
+                                           ? std::stoi(model_json.at("max_output_tokens").get<std::string>())
                                            : std::get<3>(query_result);
-    model_details_.temperature = model_json.contains("temperature") ? model_json.at("temperature").get<float>() : 0.5;
+    model_details_.temperature = model_json.contains("temperature") ? model_json.at("temperature").get<float>() : 0.7;
+    model_details_.temperature = model_json.contains("temperature") ? std::stof(model_json.at("temperature").get<std::string>()) : 0;
+    model_details_.batch_size = model_json.contains("batch_size") ? std::stoi(model_json.at("batch_size").get<std::string>()) : 0;
 }
 
 std::tuple<std::string, std::string, int32_t, int32_t> Model::GetQueriedModel(const std::string& model_name) {
diff --git a/src/prompt_manager/prompt_manager.cpp b/src/prompt_manager/prompt_manager.cpp
index d1fa90b8..23f1e35a 100644
--- a/src/prompt_manager/prompt_manager.cpp
+++ b/src/prompt_manager/prompt_manager.cpp
@@ -38,35 +38,36 @@ std::string PromptManager::ReplaceSection(const std::string& prompt_template, co
     return prompt;
 }
 
-std::string PromptManager::ConstructMarkdownHeader(const nlohmann::json& tuple) {
-    std::string header_markdown = "|";
+std::string PromptManager::ConstructInputTuplesHeader(const nlohmann::json& tuple) {
+    auto header = std::string("<header>");
     for (const auto& key : tuple.items()) {
-        header_markdown += key.key() + " | ";
+        header += "<col>" + key.key() + "</col>";
     }
-    header_markdown += "\n";
-    for (auto i = 0; i < static_cast<int>(tuple.size()); i++) {
-        header_markdown += "|---";
-    }
-    header_markdown += "|\n";
-    return header_markdown;
+    header += "</header>\n";
+    return header;
 }
 
-std::string PromptManager::ConstructMarkdownSingleTuple(const nlohmann::json& tuple) {
-    std::string tuple_markdown = "|";
+std::string PromptManager::ConstructSingleInputTuple(const nlohmann::json& tuple) {
+    auto tuple_str = std::string("<tuple>");
     for (const auto& key : tuple.items()) {
-        tuple_markdown += key.value().dump() + " | ";
+        tuple_str += "<col>" + key.value().dump() + "</col>";
     }
-    tuple_markdown += "\n";
-    return tuple_markdown;
+    tuple_str += "</tuple>\n";
+    return tuple_str;
+}
+
+std::string PromptManager::ConstructNumTuples(const int num_tuples) {
+    return "- The Number of Tuples to Generate Responses for: " + std::to_string(num_tuples) + "\n\n";
 }
 
-std::string PromptManager::ConstructMarkdownArrayTuples(const nlohmann::json& tuples) {
-    std::string tuples_markdown = "";
-    tuples_markdown += PromptManager::ConstructMarkdownHeader(tuples[0]);
+std::string PromptManager::ConstructInputTuples(const nlohmann::json& tuples) {
+    auto tuples_str = std::string("");
+    tuples_str += PromptManager::ConstructNumTuples(static_cast<int>(tuples.size()));
+    tuples_str += PromptManager::ConstructInputTuplesHeader(tuples[0]);
     for (const auto& tuple : tuples) {
-        tuples_markdown += PromptManager::ConstructMarkdownSingleTuple(tuple);
+        tuples_str += PromptManager::ConstructSingleInputTuple(tuple);
     }
-    return tuples_markdown;
+    return tuples_str;
 }
 
 PromptDetails PromptManager::CreatePromptDetails(const nlohmann::json& prompt_details_json) {