Add llm_reduce_json Aggregate Function (#114)
dorbanianas authored Dec 25, 2024
1 parent 475c34a commit 0e8990f
Showing 10 changed files with 213 additions and 22 deletions.
@@ -10,15 +10,19 @@ FlockMTL offers several powerful aggregate / reduce functions:

- **Example Use Cases**: Summarizing documents, aggregating product descriptions.

2. [**`llm_first`**](/docs/aggregate-reduce-functions/llm-first): Returns the most relevant item from a group based on a prompt.
2. [**`llm_reduce_json`**](/docs/aggregate-reduce-functions/llm-reduce-json): Aggregates multiple rows into a single JSON output using a language model, ideal for tasks like summarization or consolidating text across multiple features.

- **Example Use Cases**: Extracting key insights and sentiment from reviews, generating a summary with multiple attributes like themes and tone from survey responses.

3. [**`llm_first`**](/docs/aggregate-reduce-functions/llm-first): Returns the most relevant item from a group based on a prompt.

- **Example Use Cases**: Selecting the top-ranked document, finding the most relevant product.

3. [**`llm_last`**](/docs/aggregate-reduce-functions/llm-last): Returns the least relevant item from a group based on a prompt.
4. [**`llm_last`**](/docs/aggregate-reduce-functions/llm-last): Returns the least relevant item from a group based on a prompt.

- **Example Use Cases**: Finding the least relevant document, selecting the least important product.

4. [**`llm_rerank`**](/docs/aggregate-reduce-functions/llm-rerank): Reorders a list of rows based on relevance to a prompt using a sliding window mechanism.
5. [**`llm_rerank`**](/docs/aggregate-reduce-functions/llm-rerank): Reorders a list of rows based on relevance to a prompt using a sliding window mechanism.
- **Example Use Cases**: Reranking search results, adjusting document or product rankings.

## 2. How Aggregate / Reduce Functions Work
156 changes: 156 additions & 0 deletions docs/docs/aggregate-reduce-functions/llm-reduce-json.md
@@ -0,0 +1,156 @@
---
title: llm_reduce_json
sidebar_position: 1
---

# llm_reduce_json Aggregate Function

The `llm_reduce_json` function in FlockMTL consolidates multiple rows of text-based results into a single JSON output. It can be used with or without a `GROUP BY` clause: without one, all rows are reduced to a single result; with one, each group is reduced to its own result.

## 1. **Usage Examples**

### 1.1. **Example without `GROUP BY`**

Summarize all product descriptions into one single result:

```sql
SELECT llm_reduce_json(
{'model_name': 'gpt-4'},
{'prompt': 'Summarize the following product descriptions'},
{'product_description': product_description}
) AS product_summary
FROM products;
```

**Description**: This example aggregates all product descriptions into one summary. The `llm_reduce_json` function processes the `product_description` column for each row, consolidating the values into a single summarized output.

### 1.2. **Example with `GROUP BY`**

Group the products by category and produce one summary of the descriptions per category:

```sql
SELECT category,
llm_reduce_json(
{'model_name': 'gpt-4'},
{'prompt': 'Summarize the following product descriptions'},
{'product_description': product_description}
) AS summarized_product_info
FROM products
GROUP BY category;
```

**Description**: This query groups the products by category (e.g., electronics, clothing) and summarizes all product descriptions within each category into a single consolidated summary.

### 1.3. **Using a Named Prompt with `GROUP BY`**

Leverage a reusable named prompt for summarization, grouped by category:

```sql
SELECT category,
llm_reduce_json(
{'model_name': 'gpt-4', 'secret_name': 'azure_key'},
{'prompt_name': 'summarizer', 'version': 1},
{'product_description': product_description}
) AS summarized_product_info
FROM products
GROUP BY category;
```

**Description**: This example uses a pre-configured named prompt (`summarizer`) with version `1` to summarize product descriptions. The results are grouped by category, with one summary per category.

### 1.4. **Advanced Example with Multiple Columns and `GROUP BY`**

Summarize product details by category, using both the product name and description:

```sql
WITH product_info AS (
SELECT category, product_name, product_description
FROM products
WHERE category = 'Electronics'
)
SELECT category,
llm_reduce_json(
{'model_name': 'gpt-4'},
{'prompt': 'Summarize the following product details'},
{'product_name': product_name, 'product_description': product_description}
) AS detailed_summary
FROM product_info
GROUP BY category;
```

**Description**: In this advanced example, the query summarizes both the `product_name` and `product_description` columns for products in the "Electronics" category, generating a detailed summary for that category.

## 2. **Input Parameters**

### 2.1 **Model Configuration**

- **Parameters**: `model_name` and `secret_name`

#### 2.1.1 Model Selection

- **Description**: Specifies the model used for text generation.
- **Example**:
```sql
{ 'model_name': 'gpt-4' }
```

#### 2.1.2 Model Selection with Secret

- **Description**: Specifies the model along with the secret name to be used for authentication when accessing the model.
- **Example**:
```sql
{ 'model_name': 'gpt-4', 'secret_name': 'your_secret_name' }
```

### 2.2. **Prompt Configuration**

A prompt can be specified in three forms:

1. **Inline Prompt**

- Directly provides the prompt in the query.
- **Example**:
```sql
{'prompt': 'Summarize the following product descriptions'}
```

2. **Named Prompt**

- Refers to a pre-configured prompt by name.
- **Example**:
```sql
{'prompt_name': 'summarizer'}
```

3. **Named Prompt with Version**
- Refers to a specific version of a pre-configured prompt.
- **Example**:
```sql
{'prompt_name': 'summarizer', 'version': 1}
```

### 2.3. **Column Mappings (Optional)**

- **Key**: Column mappings.
- **Purpose**: Maps table columns to prompt variables for input.
- **Example**:
```sql
{'product_name': product_name, 'product_description': product_description}
```

## 3. **Output**

- **Type**: Text (`VARCHAR`) holding the JSON value returned by the model, serialized as a string.
- **Behavior**: The function consolidates multiple rows of text into a single output, summarizing or combining the provided data according to the model's response to the prompt.
**Output Example**:
For a query that aggregates product descriptions, the result could look like:
- **Input Rows**:
- `product_name`: _"Running Shoes"_
- `product_name`: _"Wireless Headphones"_
- `product_name`: _"Smart Watch"_
- **Output**:
`"A variety of products including running shoes, wireless headphones, and smart watches, each designed for comfort, convenience, and performance in their respective categories."`
23 changes: 12 additions & 11 deletions src/functions/aggregate/llm_reduce/implementation.cpp
@@ -2,11 +2,10 @@

namespace flockmtl {

int LlmReduce::GetAvailableTokens() {
int LlmReduce::GetAvailableTokens(const AggregateFunctionType& function_type) {
int num_tokens_meta_and_reduce_query = 0;
num_tokens_meta_and_reduce_query += Tiktoken::GetNumTokens(user_query);
num_tokens_meta_and_reduce_query +=
Tiktoken::GetNumTokens(PromptManager::GetTemplate(AggregateFunctionType::REDUCE));
num_tokens_meta_and_reduce_query += Tiktoken::GetNumTokens(PromptManager::GetTemplate(function_type));

auto model_context_size = model.GetModelDetails().context_window;
if (num_tokens_meta_and_reduce_query > model_context_size) {
@@ -17,15 +16,16 @@ int LlmReduce::GetAvailableTokens() {
return available_tokens;
}

nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples) {
nlohmann::json LlmReduce::ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type) {
nlohmann::json data;
auto prompt = PromptManager::Render(user_query, tuples, AggregateFunctionType::REDUCE);
auto prompt = PromptManager::Render(user_query, tuples, function_type);
auto response = model.CallComplete(prompt);
return response["output"];
};

nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples) {
auto available_tokens = GetAvailableTokens();
nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples,
const AggregateFunctionType& function_type) {
auto available_tokens = GetAvailableTokens(function_type);
auto accumulated_tuples_tokens = 0u;
auto batch_tuples = nlohmann::json::array();
int start_index = 0;
@@ -45,7 +45,7 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples)
accumulated_tuples_tokens += num_tokens;
start_index++;
}
auto response = ReduceBatch(batch_tuples);
auto response = ReduceBatch(batch_tuples, function_type);
batch_tuples.clear();
batch_tuples.push_back(response);
accumulated_tuples_tokens = 0u;
@@ -54,8 +54,9 @@ nlohmann::json LlmReduce::ReduceLoop(const std::vector<nlohmann::json>& tuples)
return batch_tuples[0];
}

void LlmReduce::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result,
idx_t count, idx_t offset) {
void LlmReduce::FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
duckdb::Vector& result, idx_t count, idx_t offset,
const AggregateFunctionType function_type) {
auto states_vector = duckdb::FlatVector::GetData<AggregateFunctionState*>(states);

auto function_instance = AggregateFunctionBase::GetInstance<LlmReduce>();
@@ -64,7 +65,7 @@ void LlmReduce::Finalize(duckdb::Vector& states, duckdb::AggregateInputData& agg
auto state_ptr = states_vector[idx];
auto state = function_instance->state_map[state_ptr];

auto response = function_instance->ReduceLoop(state->value);
auto response = function_instance->ReduceLoop(state->value, function_type);
result.SetValue(idx, response.dump());
}
}
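
The token budgeting and batching that `GetAvailableTokens` and `ReduceLoop` perform can be pictured with a standard-library-only sketch. It is a simplified model under stated assumptions, not the extension's code: `CountTokens` (a whitespace word counter standing in for `Tiktoken::GetNumTokens`), `AvailableTokens`, `ReduceInBatches`, and the `Reducer` callback are hypothetical names, and the handling of a tuple that is larger than the budget is a guess because that part of the loop is collapsed in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical tokenizer stand-in: counts whitespace-separated words.
inline int CountTokens(const std::string& text) {
    std::istringstream stream(text);
    std::string word;
    int count = 0;
    while (stream >> word) {
        ++count;
    }
    return count;
}

// Tokens left for tuples once the user query and the prompt template are paid for.
inline int AvailableTokens(int context_window, const std::string& user_query,
                           const std::string& prompt_template) {
    const int overhead = CountTokens(user_query) + CountTokens(prompt_template);
    if (overhead > context_window) {
        throw std::length_error("query and template exceed the context window");
    }
    return context_window - overhead;
}

// Hypothetical callback that turns one batch of tuples into a single summary.
using Reducer = std::function<std::string(const std::vector<std::string>&)>;

// Folds all tuples into one value: fill a batch up to the budget, reduce it,
// then seed the next batch with the partial result.
inline std::string ReduceInBatches(const std::vector<std::string>& tuples, int available_tokens,
                                   const Reducer& reduce_batch) {
    assert(available_tokens >= 0);
    if (tuples.empty()) {
        throw std::invalid_argument("no tuples to reduce");
    }
    std::vector<std::string> batch;
    int used_tokens = 0;
    std::size_t next = 0;
    do {
        while (next < tuples.size()) {
            const int cost = CountTokens(tuples[next]);
            if (cost > available_tokens) {
                // Assumption: an oversized tuple is an error rather than being truncated.
                throw std::length_error("a single tuple exceeds the token budget");
            }
            if (used_tokens + cost > available_tokens) {
                break;  // batch is full; reduce it before taking more tuples
            }
            batch.push_back(tuples[next]);
            used_tokens += cost;
            ++next;
        }
        const std::string partial = reduce_batch(batch);
        batch.clear();
        batch.push_back(partial);
        used_tokens = 0;  // as in the diff, the carried summary is not counted
    } while (next < tuples.size());
    return batch[0];
}
```

Because the counter is reset after every reduction, each pass admits at least one new tuple, so the loop always terminates; the trade-off is that the carried summary itself is not charged against the budget.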
4 changes: 4 additions & 0 deletions src/functions/aggregate/llm_reduce/instantiations.cpp
@@ -10,5 +10,9 @@ template void AggregateFunctionBase::SimpleUpdate<LlmReduce>(duckdb::Vector[], d
duckdb::data_ptr_t, idx_t);
template void AggregateFunctionBase::Combine<LlmReduce>(duckdb::Vector&, duckdb::Vector&, duckdb::AggregateInputData&,
idx_t);
template void LlmReduce::Finalize<AggregateFunctionType::REDUCE>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);
template void LlmReduce::Finalize<AggregateFunctionType::REDUCE_JSON>(duckdb::Vector&, duckdb::AggregateInputData&,
duckdb::Vector&, idx_t, idx_t);

} // namespace flockmtl
13 changes: 12 additions & 1 deletion src/functions/aggregate/llm_reduce/registry.cpp
@@ -7,7 +7,18 @@ void AggregateRegistry::RegisterLlmReduce(duckdb::DatabaseInstance& db) {
auto string_concat = duckdb::AggregateFunction(
"llm_reduce", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine, LlmReduce::Finalize, LlmReduce::SimpleUpdate);
LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
LlmReduce::Finalize<AggregateFunctionType::REDUCE>, LlmReduce::SimpleUpdate);

duckdb::ExtensionUtil::RegisterFunction(db, string_concat);
}

void AggregateRegistry::RegisterLlmReduceJson(duckdb::DatabaseInstance& db) {
auto string_concat = duckdb::AggregateFunction(
"llm_reduce_json", {duckdb::LogicalType::ANY, duckdb::LogicalType::ANY, duckdb::LogicalType::ANY},
duckdb::LogicalType::VARCHAR, duckdb::AggregateFunction::StateSize<AggregateFunctionState>,
LlmReduce::Initialize, LlmReduce::Operation, LlmReduce::Combine,
LlmReduce::Finalize<AggregateFunctionType::REDUCE_JSON>, LlmReduce::SimpleUpdate);

duckdb::ExtensionUtil::RegisterFunction(db, string_concat);
}
14 changes: 10 additions & 4 deletions src/include/flockmtl/functions/aggregate/llm_reduce.hpp
@@ -8,9 +8,9 @@ class LlmReduce : public AggregateFunctionBase {
public:
explicit LlmReduce() = default;

int GetAvailableTokens();
nlohmann::json ReduceBatch(const nlohmann::json& tuples);
nlohmann::json ReduceLoop(const std::vector<nlohmann::json>& tuples);
int GetAvailableTokens(const AggregateFunctionType& function_type);
nlohmann::json ReduceBatch(const nlohmann::json& tuples, const AggregateFunctionType& function_type);
nlohmann::json ReduceLoop(const std::vector<nlohmann::json>& tuples, const AggregateFunctionType& function_type);

public:
static void Initialize(const duckdb::AggregateFunction& function, duckdb::data_ptr_t state_p) {
@@ -28,8 +28,14 @@ class LlmReduce : public AggregateFunctionBase {
idx_t count) {
AggregateFunctionBase::Combine<LlmReduce>(source, target, aggr_input_data, count);
}
static void FinalizeResults(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data,
duckdb::Vector& result, idx_t count, idx_t offset,
const AggregateFunctionType function_type);
template <AggregateFunctionType function_type>
static void Finalize(duckdb::Vector& states, duckdb::AggregateInputData& aggr_input_data, duckdb::Vector& result,
idx_t count, idx_t offset);
idx_t count, idx_t offset) {
FinalizeResults(states, aggr_input_data, result, count, offset, function_type);
};
};

} // namespace flockmtl
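
The templated `Finalize` wrapper lets two SQL functions share one implementation while each still exposes the fixed callback signature that registration expects. A minimal standard-library sketch of that pattern follows; `ReduceKind`, `FinalizeCallback`, `LookupFinalize`, and `RunFinalize` are hypothetical names introduced for illustration, and strings stand in for the DuckDB vectors used by the real callbacks.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the two reduce-related values of AggregateFunctionType.
enum class ReduceKind { Reduce, ReduceJson };

// Fixed callback signature, standing in for the finalize slot of an aggregate.
using FinalizeCallback = std::string (*)(const std::string& state);

// Shared implementation: the kind arrives as an ordinary runtime argument.
inline std::string FinalizeResults(const std::string& state, ReduceKind kind) {
    const char* label = (kind == ReduceKind::ReduceJson) ? "reduce_json" : "reduce";
    return std::string(label) + ":" + state;
}

// One instantiation per kind; each one has exactly the callback signature,
// with the kind baked in at compile time.
template <ReduceKind kind>
std::string Finalize(const std::string& state) {
    return FinalizeResults(state, kind);
}

// Picks the instantiation for a SQL function name; nullptr if the name is unknown.
inline FinalizeCallback LookupFinalize(const std::string& function_name) {
    if (function_name == "llm_reduce") {
        return &Finalize<ReduceKind::Reduce>;
    }
    if (function_name == "llm_reduce_json") {
        return &Finalize<ReduceKind::ReduceJson>;
    }
    return nullptr;
}

// Convenience wrapper that requires a known function name.
inline std::string RunFinalize(const std::string& function_name, const std::string& state) {
    FinalizeCallback callback = LookupFinalize(function_name);
    assert(callback != nullptr && "unknown aggregate function name");
    return callback(state);
}
```

A non-type template parameter is used here because a plain function pointer cannot carry extra state; the alternative of two hand-written wrappers would duplicate the forwarding code.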
7 changes: 6 additions & 1 deletion src/include/flockmtl/prompt_manager/repository.hpp
@@ -6,7 +6,7 @@ namespace flockmtl {

enum class PromptSection { USER_PROMPT, TUPLES, RESPONSE_FORMAT, INSTRUCTIONS };

enum class AggregateFunctionType { REDUCE, FIRST, LAST, RERANK };
enum class AggregateFunctionType { REDUCE, REDUCE_JSON, FIRST, LAST, RERANK };

enum class ScalarFunctionType { COMPLETE_JSON, COMPLETE, FILTER };

@@ -54,6 +54,11 @@ class RESPONSE_FORMAT {
"Return a single, coherent output that synthesizes the most relevant information from the tuples provided. "
"Respond in the following JSON format. **Do not add explanations or additional words beyond the summarized "
"content.**\n\nResponse Format:\n\n```json\n{\n \"output\": <summarized content here>\n}\n```";
static constexpr auto REDUCE_JSON =
"Return a single, coherent output that synthesizes the most relevant information from the tuples provided. "
"Respond in the following JSON format where the summarized content should be in JSON format that contains "
"the necessary columns for the answer. **Do not add explanations or additional words beyond the summarized "
"content.**\n\nResponse Format:\n\n```json\n{\n \"output\": {<summarized content here>}\n}\n```";
static constexpr auto FIRST_OR_LAST =
"Identify the {{RELEVANCE}} relevant tuple based on the provided user prompt from the list of tuples. Output "
"the index of this single tuple as a JSON object in the following format:\n\n```json\n{\n \"selected\": "
1 change: 1 addition & 0 deletions src/include/flockmtl/registry/aggregate.hpp
@@ -13,6 +13,7 @@ class AggregateRegistry {
static void RegisterLlmLast(duckdb::DatabaseInstance& db);
static void RegisterLlmRerank(duckdb::DatabaseInstance& db);
static void RegisterLlmReduce(duckdb::DatabaseInstance& db);
static void RegisterLlmReduceJson(duckdb::DatabaseInstance& db);
};

} // namespace flockmtl
6 changes: 4 additions & 2 deletions src/prompt_manager/repository.cpp
@@ -32,13 +32,15 @@ std::string RESPONSE_FORMAT::Get(const AggregateFunctionType option) {
switch (option) {
case AggregateFunctionType::REDUCE:
return RESPONSE_FORMAT::REDUCE;
case AggregateFunctionType::REDUCE_JSON:
return RESPONSE_FORMAT::REDUCE_JSON;
case AggregateFunctionType::RERANK:
return RESPONSE_FORMAT::RERANK;
case AggregateFunctionType::FIRST:
case AggregateFunctionType::LAST: {
return PromptManager::ReplaceSection(RESPONSE_FORMAT::FIRST_OR_LAST, "{{RELEVANCE}}",
option == AggregateFunctionType::FIRST ? "most" : "least");
}
case AggregateFunctionType::RERANK:
return RESPONSE_FORMAT::RERANK;
default:
return "";
}
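
The selection logic in `RESPONSE_FORMAT::Get` can be sketched in isolation as an enum-to-template mapping, with one shared template for the first/last case that is specialized by placeholder substitution. This is an illustrative model only: `AggregateKind`, `ReplaceAll`, and `ResponseFormat` are hypothetical names, `ReplaceAll` is an assumed stand-in for `PromptManager::ReplaceSection`, and the template strings are shortened placeholders rather than the real prompt text.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical mirror of AggregateFunctionType after this commit.
enum class AggregateKind { Reduce, ReduceJson, First, Last, Rerank };

// Replaces every occurrence of `placeholder` in `text` with `value`.
inline std::string ReplaceAll(std::string text, const std::string& placeholder,
                              const std::string& value) {
    assert(!placeholder.empty());
    std::size_t position = 0;
    while ((position = text.find(placeholder, position)) != std::string::npos) {
        text.replace(position, placeholder.size(), value);
        position += value.size();  // skip past the inserted value
    }
    return text;
}

// Returns the (abbreviated, illustrative) response-format text for a kind.
inline std::string ResponseFormat(AggregateKind kind) {
    switch (kind) {
    case AggregateKind::Reduce:
        return "Return {\"output\": <summary>}";
    case AggregateKind::ReduceJson:
        return "Return {\"output\": {<summary fields>}}";
    case AggregateKind::First:
    case AggregateKind::Last:
        // One shared template; only the relevance word differs.
        return ReplaceAll("Select the {{RELEVANCE}} relevant tuple", "{{RELEVANCE}}",
                          kind == AggregateKind::First ? "most" : "least");
    case AggregateKind::Rerank:
        return "Return the reordered tuple indexes";
    }
    return "";  // unreachable for valid enum values
}
```

Adding a new aggregate variant then amounts to one new enumerator plus one new `case`, which is the shape of the change this commit makes for `REDUCE_JSON`.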
1 change: 1 addition & 0 deletions src/registry/aggregate.cpp
@@ -7,6 +7,7 @@ void AggregateRegistry::Register(duckdb::DatabaseInstance& db) {
RegisterLlmLast(db);
RegisterLlmRerank(db);
RegisterLlmReduce(db);
RegisterLlmReduceJson(db);
}

} // namespace flockmtl
