diff --git a/docs/docs/aggregate-reduce-functions/aggregate-reduce-functions.md b/docs/docs/aggregate-reduce-functions/aggregate-reduce-functions.md
index d49d27c4..50f59c89 100644
--- a/docs/docs/aggregate-reduce-functions/aggregate-reduce-functions.md
+++ b/docs/docs/aggregate-reduce-functions/aggregate-reduce-functions.md
@@ -10,15 +10,19 @@ FlockMTL offers several powerful aggregate / reduce functions:
 
    - **Example Use Cases**: Summarizing documents, aggregating product descriptions.
 
-2. [**`llm_first`**](/docs/aggregate-reduce-functions/llm-first): Returns the most relevant item from a group based on a prompt.
+2. [**`llm_reduce_json`**](/docs/aggregate-reduce-functions/llm-reduce-json): Aggregates multiple rows into a single JSON output using a language model, ideal for tasks like summarization or consolidating text across multiple features.
+
+   - **Example Use Cases**: Extracting key insights and sentiment from reviews, generating a summary with multiple attributes like themes and tone from survey responses.
+
+3. [**`llm_first`**](/docs/aggregate-reduce-functions/llm-first): Returns the most relevant item from a group based on a prompt.
 
    - **Example Use Cases**: Selecting the top-ranked document, finding the most relevant product.
 
-3. [**`llm_last`**](/docs/aggregate-reduce-functions/llm-last): Returns the least relevant item from a group based on a prompt.
+4. [**`llm_last`**](/docs/aggregate-reduce-functions/llm-last): Returns the least relevant item from a group based on a prompt.
 
    - **Example Use Cases**: Finding the least relevant document, selecting the least important product.
 
-4. [**`llm_rerank`**](/docs/aggregate-reduce-functions/llm-rerank): Reorders a list of rows based on relevance to a prompt using a sliding window mechanism.
+5. [**`llm_rerank`**](/docs/aggregate-reduce-functions/llm-rerank): Reorders a list of rows based on relevance to a prompt using a sliding window mechanism.
 
    - **Example Use Cases**: Reranking search results, adjusting document or product rankings.
 ## 2. How Aggregate / Reduce Functions Work
diff --git a/docs/docs/aggregate-reduce-functions/llm-reduce-json.md b/docs/docs/aggregate-reduce-functions/llm-reduce-json.md
new file mode 100644
index 00000000..35bf200a
--- /dev/null
+++ b/docs/docs/aggregate-reduce-functions/llm-reduce-json.md
@@ -0,0 +1,156 @@
+---
+title: llm_reduce_json
+sidebar_position: 1
+---
+
+# llm_reduce_json Aggregate Function
+
+The `llm_reduce_json` function in FlockMTL consolidates multiple rows of text-based results into a single JSON output. It is typically used with a `GROUP BY` clause to produce one result per group; without `GROUP BY`, it reduces all rows to a single result.
+
+## 1. **Usage Examples**
+
+### 1.1. **Example without `GROUP BY`**
+
+Summarize all product descriptions into a single result:
+
+```sql
+SELECT llm_reduce_json(
+    {'model_name': 'gpt-4'},
+    {'prompt': 'Summarize the following product descriptions'},
+    {'product_description': product_description}
+) AS product_summary
+FROM products;
+```
+
+**Description**: This example aggregates all product descriptions into one summary. The `llm_reduce_json` function processes the `product_description` column for each row, consolidating the values into a single summarized output.
+
+### 1.2. **Example with `GROUP BY`**
+
+Group the products by category and summarize their descriptions into one summary per category:
+
+```sql
+SELECT category,
+       llm_reduce_json(
+           {'model_name': 'gpt-4'},
+           {'prompt': 'Summarize the following product descriptions'},
+           {'product_description': product_description}
+       ) AS summarized_product_info
+FROM products
+GROUP BY category;
+```
+
+**Description**: This query groups the products by category (e.g., electronics, clothing) and summarizes all product descriptions within each category into a single consolidated summary.
+
+### 1.3. **Using a Named Prompt with `GROUP BY`**
+
+Leverage a reusable named prompt for summarization, grouped by category:
+
+```sql
+SELECT category,
+       llm_reduce_json(
+           {'model_name': 'gpt-4', 'secret_name': 'azure_key'},
+           {'prompt_name': 'summarizer', 'version': 1},
+           {'product_description': product_description}
+       ) AS summarized_product_info
+FROM products
+GROUP BY category;
+```
+
+**Description**: This example uses a pre-configured named prompt (`summarizer`) with version `1` to summarize product descriptions. The results are grouped by category, with one summary per category.
+
+### 1.4. **Advanced Example with Multiple Columns and `GROUP BY`**
+
+Summarize product details by category, using both the product name and description:
+
+```sql
+WITH product_info AS (
+    SELECT category, product_name, product_description
+    FROM products
+    WHERE category = 'Electronics'
+)
+SELECT category,
+       llm_reduce_json(
+           {'model_name': 'gpt-4'},
+           {'prompt': 'Summarize the following product details'},
+           {'product_name': product_name, 'product_description': product_description}
+       ) AS detailed_summary
+FROM product_info
+GROUP BY category;
+```
+
+**Description**: In this advanced example, the query summarizes both the `product_name` and `product_description` columns for products in the "Electronics" category, generating a detailed summary for that category.
+
+## 2. **Input Parameters**
+
+### 2.1. **Model Configuration**
+
+- **Parameter**: `model_name` and `secret_name`
+
+#### 2.1.1 Model Selection
+
+- **Description**: Specifies the model used for text generation.
+- **Example**:
+  ```sql
+  { 'model_name': 'gpt-4' }
+  ```
+
+#### 2.1.2 Model Selection with Secret
+
+- **Description**: Specifies the model along with the secret name to be used for authentication when accessing the model.
+- **Example**:
+  ```sql
+  { 'model_name': 'gpt-4', 'secret_name': 'your_secret_name' }
+  ```
+
+### 2.2. **Prompt Configuration**
+
+Three types of prompts can be used:
+
+1. **Inline Prompt**
+
+   - Directly provides the prompt in the query.
+   - **Example**:
+     ```sql
+     {'prompt': 'Summarize the following product descriptions'}
+     ```
+
+2. **Named Prompt**
+
+   - Refers to a pre-configured prompt by name.
+   - **Example**:
+     ```sql
+     {'prompt_name': 'summarizer'}
+     ```
+
+3. **Named Prompt with Version**
+   - Refers to a specific version of a pre-configured prompt.
+   - **Example**:
+     ```sql
+     {'prompt_name': 'summarizer', 'version': 1}
+     ```
+
+### 2.3. **Column Mappings (Optional)**
+
+- **Key**: Column mappings.
+- **Purpose**: Maps table columns to prompt variables for input.
+- **Example**:
+  ```sql
+  {'product_name': product_name, 'product_description': product_description}
+  ```
+
+## 3. **Output**
+
+- **Type**: `VARCHAR` (the JSON result serialized as a string).
+- **Behavior**: The function consolidates multiple rows into a single JSON output, summarizing or combining the provided data according to the model's response to the prompt.
+
+**Output Example**:
+For a query that aggregates product descriptions, the result could look like:
+
+- **Input Rows**:
+
+  - `product_name`: _"Running Shoes"_
+  - `product_name`: _"Wireless Headphones"_
+  - `product_name`: _"Smart Watch"_
+
+- **Output**:
+  `"A variety of products including running shoes, wireless headphones, and smart watches, each designed for comfort, convenience, and performance in their respective categories."`
diff --git a/src/functions/aggregate/llm_reduce/implementation.cpp b/src/functions/aggregate/llm_reduce/implementation.cpp
index 600f35d7..82792318 100644
--- a/src/functions/aggregate/llm_reduce/implementation.cpp
+++ b/src/functions/aggregate/llm_reduce/implementation.cpp
@@ -2,11 +2,10 @@
 
 namespace flockmtl {
 
-int LlmReduce::GetAvailableTokens() {
+int LlmReduce::GetAvailableTokens(const AggregateFunctionType& function_type) {
     int num_tokens_meta_and_reduce_query = 0;
     num_tokens_meta_and_reduce_query += Tiktoken::GetNumTokens(user_query);
-    num_tokens_meta_and_reduce_query +=
-        Tiktoken::GetNumTokens(PromptManager::GetTemplate(AggregateFunctionType::REDUCE));
+    num_tokens_meta_and_reduce_query += Tiktoken::GetNumTokens(PromptManager::GetTemplate(function_type));
 
     auto model_context_size = model.GetModelDetails().context_window;
     if (num_tokens_meta_and_reduce_query > model_context_size) {
@@ -17,15 +16,16 @@ int LlmReduce::GetAvailableTokens() {
     return available_tokens;
 }
 
-nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples) {
+nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type) {
     nlohmann::json data;
-    auto prompt = PromptManager::Render(user_query, tuples, AggregateFunctionType::REDUCE);
+    auto prompt = PromptManager::Render(user_query, tuples, function_type);
     auto response = model.CallComplete(prompt);
     return response["output"];
 };
 
-nlohmann::json LlmReduce::ReduceLoop(const std::vector& tuples) {
-    auto available_tokens = GetAvailableTokens();
+nlohmann::json LlmReduce::ReduceLoop(const std::vector& tuples,
+                                     const AggregateFunctionType& function_type) {
+    auto available_tokens = GetAvailableTokens(function_type);
     auto accumulated_tuples_tokens = 0u;
     auto batch_tuples = nlohmann::json::array();
     int start_index = 0;
@@ -45,7 +45,7 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector& tuples)
             accumulated_tuples_tokens += num_tokens;
             start_index++;
         }
-        auto response = ReduceBatch(batch_tuples);
+        auto response = ReduceBatch(batch_tuples, function_type);
         batch_tuples.clear();
         batch_tuples.push_back(response);
         accumulated_tuples_tokens = 0u;
@@ -54,8 +54,9 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector& tuples)
     return batch_tuples[0];
 }
 
-void LlmReduce::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result,
-                         idx_t count, idx_t offset) {
+void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
+                                duckdb::Vector& result, idx_t count, idx_t offset,
+                                const AggregateFunctionType function_type) {
     auto states_vector = duckdb::FlatVector::GetData(states);
 
     auto function_instance = AggregateFunctionBase::GetInstance();
@@ -64,7 +65,7 @@ void LlmReduce::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& agg
         auto state_ptr = states_vector[idx];
         auto state = function_instance->state_map[state_ptr];
 
-        auto response = function_instance->ReduceLoop(state->value);
+        auto response = function_instance->ReduceLoop(state->value, function_type);
         result.SetValue(idx, response.dump());
     }
 }
diff --git a/src/functions/aggregate/llm_reduce/instantiations.cpp b/src/functions/aggregate/llm_reduce/instantiations.cpp
index 86248d36..ab6d9858 100644
--- a/src/functions/aggregate/llm_reduce/instantiations.cpp
+++ b/src/functions/aggregate/llm_reduce/instantiations.cpp
@@ -10,5 +10,9 @@ template void AggregateFunctionBase::SimpleUpdate(duckdb::Vector[], d
                                                   duckdb::data_ptr_t, idx_t);
 template void AggregateFunctionBase::Combine(duckdb::Vector&, duckdb::Vector&, duckdb::AggregateInputData&,
                                              idx_t);
+template void LlmReduce::Finalize<AggregateFunctionType::REDUCE>(duckdb::Vector&, duckdb::AggregateInputData&,
+                                                                 duckdb::Vector&, idx_t, idx_t);
+template void LlmReduce::Finalize<AggregateFunctionType::REDUCE_JSON>(duckdb::Vector&, duckdb::AggregateInputData&,
+                                                                      duckdb::Vector&, idx_t, idx_t);
 
 } // namespace flockmtl
diff --git a/src/functions/aggregate/llm_reduce/registry.cpp b/src/functions/aggregate/llm_reduce/registry.cpp
index 07e3e188..227fd273 100644
--- a/src/functions/aggregate/llm_reduce/registry.cpp
+++ b/src/functions/aggregate/llm_reduce/registry.cpp
@@ -7,7 +7,18 @@ void AggregateRegistry::RegisterLlmReduce(duckdb::DatabaseInstance& db) {
     auto string_concat = duckdb::AggregateFunction(
         "llm_reduce", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
         duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize,
-        LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine, LlmReduce::Finalize, LlmReduce::SimpleUpdate);
+        LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
+        LlmReduce::Finalize<AggregateFunctionType::REDUCE>, LlmReduce::SimpleUpdate);
+
+    duckdb::ExtensionUtil::RegisterFunction(db, string_concat);
+}
+
+void AggregateRegistry::RegisterLlmReduceJson(duckdb::DatabaseInstance& db) {
+    auto string_concat = duckdb::AggregateFunction(
+        "llm_reduce_json", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
+        duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize,
+        LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
+        LlmReduce::Finalize<AggregateFunctionType::REDUCE_JSON>, LlmReduce::SimpleUpdate);
 
     duckdb::ExtensionUtil::RegisterFunction(db, string_concat);
 }
diff --git a/src/include/flockmtl/functions/aggregate/llm_reduce.hpp b/src/include/flockmtl/functions/aggregate/llm_reduce.hpp
index a46dc43f..9f32a971 100644
--- a/src/include/flockmtl/functions/aggregate/llm_reduce.hpp
+++ b/src/include/flockmtl/functions/aggregate/llm_reduce.hpp
@@ -8,9 +8,9 @@ class LlmReduce : public AggregateFunctionBase {
 public:
     explicit LlmReduce() = default;
 
-    int GetAvailableTokens();
-    nlohmann::json ReduceBatch(const nlohmann::json& tuples);
-    nlohmann::json ReduceLoop(const std::vector& tuples);
+    int GetAvailableTokens(const AggregateFunctionType& function_type);
+    nlohmann::json ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type);
+    nlohmann::json ReduceLoop(const std::vector& tuples, const AggregateFunctionType& function_type);
 
 public:
     static void Initialize(const duckdb::AggregateFunction& function, duckdb::data_ptr_t state_p) {
@@ -28,8 +28,14 @@ class LlmReduce : public AggregateFunctionBase {
                         idx_t count) {
         AggregateFunctionBase::Combine(source, target, aggr_input_data, count);
     }
+    static void FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
+                                duckdb::Vector& result, idx_t count, idx_t offset,
+                                const AggregateFunctionType function_type);
+
+    template <AggregateFunctionType function_type>
     static void Finalize(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result,
-                         idx_t count, idx_t offset);
+                         idx_t count, idx_t offset) {
+        FinalizeResults(states, aggr_input_data, result, count, offset, function_type);
+    };
 };
 
 } // namespace flockmtl
diff --git a/src/include/flockmtl/prompt_manager/repository.hpp b/src/include/flockmtl/prompt_manager/repository.hpp
index 501c1175..1af5459e 100644
--- a/src/include/flockmtl/prompt_manager/repository.hpp
+++ b/src/include/flockmtl/prompt_manager/repository.hpp
@@ -6,7 +6,7 @@ namespace flockmtl {
 
 enum class PromptSection { USER_PROMPT, TUPLES, RESPONSE_FORMAT, INSTRUCTIONS };
 
-enum class AggregateFunctionType { REDUCE, FIRST, LAST, RERANK };
+enum class AggregateFunctionType { REDUCE, REDUCE_JSON, FIRST, LAST, RERANK };
 
 enum class ScalarFunctionType { COMPLETE_JSON, COMPLETE, FILTER };
 
@@ -54,6 +54,11 @@ class RESPONSE_FORMAT {
         "Return a single, coherent output that synthesizes the most relevant information from the tuples provided. "
         "Respond in the following JSON format. **Do not add explanations or additional words beyond the summarized "
         "content.**\n\nResponse Format:\n\n```json\n{\n \"output\": \n}\n```";
+    static constexpr auto REDUCE_JSON =
+        "Return a single, coherent output that synthesizes the most relevant information from the tuples provided. "
+        "Respond in the following JSON format where the summarized content should be in JSON format that contains "
+        "the necessary columns for the answer. **Do not add explanations or additional words beyond the summarized "
+        "content.**\n\nResponse Format:\n\n```json\n{\n \"output\": {}\n}\n```";
     static constexpr auto FIRST_OR_LAST =
         "Identify the {{RELEVANCE}} relevant tuple based on the provided user prompt from the list of tuples. Output "
         "the index of this single tuple as a JSON object in the following format:\n\n```json\n{\n \"selected\": "
diff --git a/src/include/flockmtl/registry/aggregate.hpp b/src/include/flockmtl/registry/aggregate.hpp
index 3382e108..a600cac7 100644
--- a/src/include/flockmtl/registry/aggregate.hpp
+++ b/src/include/flockmtl/registry/aggregate.hpp
@@ -13,6 +13,7 @@ class AggregateRegistry {
     static void RegisterLlmLast(duckdb::DatabaseInstance& db);
     static void RegisterLlmRerank(duckdb::DatabaseInstance& db);
     static void RegisterLlmReduce(duckdb::DatabaseInstance& db);
+    static void RegisterLlmReduceJson(duckdb::DatabaseInstance& db);
 };
 
 } // namespace flockmtl
diff --git a/src/prompt_manager/repository.cpp b/src/prompt_manager/repository.cpp
index 83a40692..6532c52b 100644
--- a/src/prompt_manager/repository.cpp
+++ b/src/prompt_manager/repository.cpp
@@ -32,13 +32,15 @@ std::string RESPONSE_FORMAT::Get(const AggregateFunctionType option) {
     switch (option) {
     case AggregateFunctionType::REDUCE:
         return RESPONSE_FORMAT::REDUCE;
+    case AggregateFunctionType::REDUCE_JSON:
+        return RESPONSE_FORMAT::REDUCE_JSON;
+    case AggregateFunctionType::RERANK:
+        return RESPONSE_FORMAT::RERANK;
     case AggregateFunctionType::FIRST:
     case AggregateFunctionType::LAST: {
         return PromptManager::ReplaceSection(RESPONSE_FORMAT::FIRST_OR_LAST, "{{RELEVANCE}}",
                                              option == AggregateFunctionType::FIRST ? "most" : "least");
     }
-    case AggregateFunctionType::RERANK:
-        return RESPONSE_FORMAT::RERANK;
     default:
         return "";
     }
diff --git a/src/registry/aggregate.cpp b/src/registry/aggregate.cpp
index 588c4c8e..38bec2c0 100644
--- a/src/registry/aggregate.cpp
+++ b/src/registry/aggregate.cpp
@@ -7,6 +7,7 @@ void AggregateRegistry::Register(duckdb::DatabaseInstance& db) {
     RegisterLlmLast(db);
     RegisterLlmRerank(db);
     RegisterLlmReduce(db);
+    RegisterLlmReduceJson(db);
 }
 
 } // namespace flockmtl
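
Note on the `Finalize` change: the patch turns `LlmReduce::Finalize` into a template that forwards to `FinalizeResults`, passing the function type along. A plausible reason (an assumption, since the patch does not say so) is that the finalize callback handed to `duckdb::AggregateFunction` has a fixed signature, so the function type cannot travel as an extra runtime argument and is instead baked into each instantiation. The standalone sketch below illustrates that pattern only. `FunctionType`, `FinalizeFn`, `RunFinalize`, and the string rows are hypothetical stand-ins, not FlockMTL or DuckDB code.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for flockmtl's AggregateFunctionType.
enum class FunctionType { REDUCE, REDUCE_JSON };

// Fixed callback signature, standing in for the engine's finalize hook.
// Note that it has no slot for a FunctionType argument.
using FinalizeFn = void (*)(const std::vector<std::string>& rows, std::string& result);

// Runtime worker: receives the function type as an ordinary argument.
inline void FinalizeResults(const std::vector<std::string>& rows, std::string& result, FunctionType type) {
    std::string joined;
    for (const auto& row : rows) {
        if (!joined.empty()) {
            joined += "; ";
        }
        joined += row;
    }
    // Toy formatting only (no JSON escaping); a real reducer would call a model.
    result = (type == FunctionType::REDUCE_JSON) ? "{\"output\": \"" + joined + "\"}" : joined;
}

// Compile-time wrapper: each instantiation matches FinalizeFn exactly and
// forwards its template argument to the runtime worker.
template <FunctionType function_type>
void Finalize(const std::vector<std::string>& rows, std::string& result) {
    FinalizeResults(rows, result, function_type);
}

// Helper that invokes a registered callback, as an engine would.
inline std::string RunFinalize(FinalizeFn fn, const std::vector<std::string>& rows) {
    std::string result;
    fn(rows, result);
    return result;
}
```

Registering `&Finalize<FunctionType::REDUCE>` for one SQL function and `&Finalize<FunctionType::REDUCE_JSON>` for another lets both share a single worker while still satisfying the fixed callback type, which mirrors how the patch registers `llm_reduce` and `llm_reduce_json` against the same `LlmReduce` class.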