Skip to content

Commit

Permalink
Add llm_max & llm_min Aggregate Functions for Selecting Most/Least Relevant Row by group (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas authored Oct 29, 2024
1 parent f335913 commit d12884d
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/core/functions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(scalar)
add_subdirectory(aggregate)

set(EXTENSION_SOURCES
${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/prompt_builder.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/core/functions/aggregate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(EXTENSION_SOURCES
${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/llm_max.cpp
${CMAKE_CURRENT_SOURCE_DIR}/llm_min.cpp
${CMAKE_CURRENT_SOURCE_DIR}/llm_min_or_max.cpp
PARENT_SCOPE)
20 changes: 20 additions & 0 deletions src/core/functions/aggregate/llm_max.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <inja/inja.hpp>
#include <nlohmann/json.hpp>
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/aggregate.hpp>
#include <flockmtl/core/functions/aggregate/llm_min_or_max.hpp>

namespace flockmtl {
namespace core {

// Registers the `llm_max` aggregate:
//   llm_max(prompt_name VARCHAR, model_name VARCHAR, tuple ANY) -> JSON
// Per group, returns the tuple the LLM judges MOST relevant (Finalize is the
// MinOrMax::MAX specialization).
void CoreAggregateFunctions::RegisterLlmMaxFunction(DatabaseInstance &db) {
    // Renamed local from `string_concat` (copy-paste leftover) to match what
    // is actually being registered.
    auto llm_max = AggregateFunction(
        "llm_max", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::ANY}, LogicalType::JSON(),
        AggregateFunction::StateSize<LlmMinOrMaxState>, LlmMinOrMaxOperation::Initialize, LlmMinOrMaxOperation::Operation,
        LlmMinOrMaxOperation::Combine, LlmMinOrMaxOperation::Finalize<MinOrMax::MAX>, LlmMinOrMaxOperation::SimpleUpdate);

    ExtensionUtil::RegisterFunction(db, llm_max);
}

} // namespace core
} // namespace flockmtl
20 changes: 20 additions & 0 deletions src/core/functions/aggregate/llm_min.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include <inja/inja.hpp>
#include <nlohmann/json.hpp>
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/aggregate.hpp>
#include <flockmtl/core/functions/aggregate/llm_min_or_max.hpp>

namespace flockmtl {
namespace core {

// Registers the `llm_min` aggregate:
//   llm_min(prompt_name VARCHAR, model_name VARCHAR, tuple ANY) -> JSON
// Per group, returns the tuple the LLM judges LEAST relevant (Finalize is the
// MinOrMax::MIN specialization).
void CoreAggregateFunctions::RegisterLlmMinFunction(DatabaseInstance &db) {
    // Renamed local from `string_concat` (copy-paste leftover) to match what
    // is actually being registered.
    auto llm_min = AggregateFunction(
        "llm_min", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::ANY}, LogicalType::JSON(),
        AggregateFunction::StateSize<LlmMinOrMaxState>, LlmMinOrMaxOperation::Initialize, LlmMinOrMaxOperation::Operation,
        LlmMinOrMaxOperation::Combine, LlmMinOrMaxOperation::Finalize<MinOrMax::MIN>, LlmMinOrMaxOperation::SimpleUpdate);

    ExtensionUtil::RegisterFunction(db, llm_min);
}

} // namespace core
} // namespace flockmtl
215 changes: 215 additions & 0 deletions src/core/functions/aggregate/llm_min_or_max.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#include <cstdint>
#include <vector>

#include <flockmtl/core/functions/aggregate/llm_min_or_max.hpp>
#include <flockmtl/core/functions/prompt_builder.hpp>

#include "flockmtl/core/module.hpp"
#include "flockmtl/core/model_manager/model_manager.hpp"
#include "flockmtl/core/model_manager/tiktoken.hpp"

namespace flockmtl {
namespace core {

// Intentionally empty: the usable state object is created lazily (and
// default-constructed) by LlmMinOrMaxOperation::Initialize via state_map;
// `value` starts out as an empty collection.
void LlmMinOrMaxState::Initialize() {
}

// Accumulates one tuple (already converted to JSON) into this group's state.
void LlmMinOrMaxState::Update(const nlohmann::json &input) {
    value.emplace_back(input);
}

// Merges another partial state into this one (parallel aggregation).
// Fixes two defects in the original `value = std::move(source.value);`:
//   1. std::move on a const& cannot move — it silently copied anyway;
//   2. assignment OVERWROTE this state's already-accumulated tuples instead
//      of merging them, losing data when both states were non-empty.
void LlmMinOrMaxState::Combine(const LlmMinOrMaxState &source) {
    value.insert(value.end(), source.value.begin(), source.value.end());
}

// Binds an evaluator to one model/prompt pair and pre-computes the token
// budget left for tuple payloads after the fixed parts of every request
// (user prompt + prompt template) are accounted for.
// Throws if the fixed parts alone already exceed the model's context window,
// since no tuple could ever fit.
LlmMinOrMax::LlmMinOrMax(std::string &model, int model_context_size, std::string &user_prompt,
                         std::string &llm_min_or_max_template)
    : model(model), model_context_size(model_context_size), user_prompt(user_prompt),
      llm_min_or_max_template(llm_min_or_max_template) {

    auto num_tokens_meta_and_user_prompts = calculateFixedTokens();

    if (num_tokens_meta_and_user_prompts > model_context_size) {
        throw std::runtime_error("Fixed tokens exceed model context size");
    }

    available_tokens = model_context_size - num_tokens_meta_and_user_prompts;
}

// Token cost of the parts present in every request regardless of batch
// content: the user prompt plus the prompt template itself.
int LlmMinOrMax::calculateFixedTokens() const {
    return Tiktoken::GetNumTokens(user_prompt) + Tiktoken::GetNumTokens(llm_min_or_max_template);
}

// Renders the prompt template with the candidate tuples and the user's
// prompt, asks the model to pick one tuple, and returns the model's choice
// (the value under the "selected" key of the completion).
nlohmann::json LlmMinOrMax::GetMaxOrMinTupleId(const nlohmann::json &tuples) {
    inja::Environment env;

    nlohmann::json template_data;
    template_data["tuples"] = tuples;
    template_data["user_prompt"] = user_prompt;

    auto rendered_prompt = env.render(llm_min_or_max_template, template_data);

    nlohmann::json settings;
    auto completion = ModelManager::CallComplete(rendered_prompt, model, settings);
    return completion["selected"];
}

// Tournament-style reduction: batch the tuples so each batch fits the model's
// available token budget, ask the LLM to pick the best tuple per batch, and
// repeat on the winners until a single tuple remains. Returns that tuple's
// "content" field.
//
// Fixes vs. the original:
//   * `uint32_t num_tuples_per_batch[num_tuples]` was a variable-length
//     array — a non-standard Clang extension (flagged by every wasm CI job);
//     replaced with std::vector.
//   * The batch start index was computed as `i * num_tuples_per_batch[i]`,
//     which is only correct when all batches happen to be the same size;
//     replaced with a running offset.
//   * Signed/unsigned loop-counter mismatch and a stray `;` after the while
//     block removed.
nlohmann::json LlmMinOrMax::Evaluate(nlohmann::json &tuples) {
    while (tuples.size() > 1) {
        auto num_tuples = tuples.size();

        // Sizes of consecutive batches; batch k covers the next
        // num_tuples_per_batch[k] tuples.
        std::vector<uint32_t> num_tuples_per_batch;

        auto num_used_tokens = 0u;
        uint32_t batch_size = 0;
        for (size_t i = 0; i < num_tuples; i++) {
            num_used_tokens += Tiktoken::GetNumTokens(tuples[i].dump());
            batch_size++;

            if (num_used_tokens >= available_tokens) {
                // Budget exhausted: close the current batch.
                num_tuples_per_batch.push_back(batch_size);
                num_used_tokens = 0;
                batch_size = 0;
            } else if (i == num_tuples - 1) {
                // Last tuple: flush the trailing partial batch.
                num_tuples_per_batch.push_back(batch_size);
            }
        }

        auto responses = nlohmann::json::array();

        // Running offset into `tuples`; batches are contiguous but not
        // necessarily equally sized.
        size_t start_index = 0;
        for (auto tuples_in_batch : num_tuples_per_batch) {
            auto end_index = start_index + tuples_in_batch;
            auto batch = nlohmann::json::array();

            for (auto j = start_index; j < end_index; j++) {
                batch.push_back(tuples[j]);
            }

            // The model answers with the index of its pick within the batch.
            auto selected_index = GetMaxOrMinTupleId(batch);
            responses.push_back(batch[selected_index.get<int>()]);

            start_index = end_index;
        }

        tuples = responses;
    }

    return tuples[0]["content"];
}

// Static member initialization.
// NOTE(review): these are process-global and shared by every concurrent
// aggregation — the last Operation/SimpleUpdate call wins for
// model_name/prompt_name, and state_map entries are never erased. Confirm
// single-threaded execution is guaranteed by the registration flags.
std::string LlmMinOrMaxOperation::model_name;
std::string LlmMinOrMaxOperation::prompt_name;
std::unordered_map<void *, std::shared_ptr<LlmMinOrMaxState>> LlmMinOrMaxOperation::state_map;


// Called by DuckDB for each fresh aggregate state. The raw state bytes are
// used only as a lookup key; the real state object (with proper C++ lifetime)
// is a shared_ptr stored in the global state_map.
// NOTE(review): entries are never removed from state_map, so states
// accumulate for the lifetime of the process — consider erasing them after
// finalization.
void LlmMinOrMaxOperation::Initialize(const AggregateFunction &, data_ptr_t state_p) {
    auto state_ptr = reinterpret_cast<LlmMinOrMaxState *>(state_p);

    if (state_map.find(state_ptr) == state_map.end()) {
        auto state = std::make_shared<LlmMinOrMaxState>();
        state->Initialize();
        state_map[state_ptr] = state;
    }
}

// Per-row update entry point: converts each input struct row to JSON and
// appends it to the state of the group it belongs to.
// Arguments: inputs[0] = prompt name, inputs[1] = model name,
// inputs[2] = the row payload (must be a STRUCT).
void LlmMinOrMaxOperation::Operation(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
                                     idx_t count) {
    // NOTE(review): only row 0's prompt/model values are read — this assumes
    // they are constant across the vector; verify for non-constant arguments.
    prompt_name = inputs[0].GetValue(0).ToString();
    model_name = inputs[1].GetValue(0).ToString();

    if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
        throw std::runtime_error("Expected a struct type for prompt inputs");
    }
    auto tuples = CastVectorOfStructsToJson(inputs[2], count);

    auto states_vector = FlatVector::GetData<LlmMinOrMaxState *>(states);

    for (idx_t i = 0; i < count; i++) {
        auto tuple = tuples[i];
        auto state_ptr = states_vector[i];

        // The raw DuckDB state pointer is just a key into the shared
        // state_map (see Initialize above).
        auto state = state_map[state_ptr];
        state->Update(tuple);
    }
}

// Merges pairs of partial aggregation states (used when DuckDB combines
// thread-local states). Each raw pointer is resolved through state_map to
// the real state object, then the source is folded into the target.
void LlmMinOrMaxOperation::Combine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) {
    auto source_vector = FlatVector::GetData<LlmMinOrMaxState *>(source);
    auto target_vector = FlatVector::GetData<LlmMinOrMaxState *>(target);

    // idx_t (not `auto i = 0`, i.e. int) avoids a signed/unsigned comparison
    // against count.
    for (idx_t i = 0; i < count; i++) {
        auto source_state = state_map[source_vector[i]];
        auto target_state = state_map[target_vector[i]];

        target_state->Combine(*source_state);
    }
}

// Shared Finalize implementation for llm_min/llm_max: for each group, tags
// every accumulated tuple with an id, runs the tournament Evaluate(), and
// writes the winning tuple (serialized JSON) into the result vector.
void LlmMinOrMaxOperation::FinalizeResults(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
                                           idx_t offset, string llm_prompt_template) {
    auto states_vector = FlatVector::GetData<LlmMinOrMaxState *>(states);

    // The model lookup depends only on model_name, so run it ONCE instead of
    // once per group (it was loop-invariant in the original).
    // NOTE(review): model_name is spliced into the SQL text; it originates
    // from a query argument, so consider a prepared statement / escaping.
    auto query_result = CoreModule::GetConnection().Query(
        "SELECT model, max_tokens FROM flockmtl_config.FLOCKMTL_MODEL_INTERNAL_TABLE WHERE model_name = '" +
        model_name + "'");

    if (query_result->RowCount() == 0) {
        throw std::runtime_error("Model not found");
    }

    auto model = query_result->GetValue(0, 0).ToString();
    auto model_context_size = query_result->GetValue(1, 0).GetValue<int>();

    for (idx_t i = 0; i < count; i++) {
        auto idx = i + offset;
        auto state_ptr = states_vector[idx];
        auto state = state_map[state_ptr];

        // Tag each tuple with a stable id so the LLM can answer with an
        // index. (Inner counter renamed from `i`, which shadowed the outer
        // loop variable, and made unsigned to match value.size().)
        auto tuples_with_ids = nlohmann::json::array();
        for (size_t j = 0; j < state->value.size(); j++) {
            auto tuple_with_id = nlohmann::json::object();
            tuple_with_id["id"] = j;
            tuple_with_id["content"] = state->value[j];
            tuples_with_ids.push_back(tuple_with_id);
        }

        LlmMinOrMax llm_min_or_max(model, model_context_size, prompt_name, llm_prompt_template);
        auto response = llm_min_or_max.Evaluate(tuples_with_ids);
        result.SetValue(idx, response.dump());
    }
}

// Finalize specialization for llm_min: resolves each group to the tuple the
// model judges LEAST relevant. (Stray `;` after the function body removed.)
template <>
void LlmMinOrMaxOperation::Finalize<MinOrMax::MIN>(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
                                                   idx_t offset) {
    FinalizeResults(states, aggr_input_data, result, count, offset, GetMinOrMaxPromptTemplate<MinOrMax::MIN>());
}

// Finalize specialization for llm_max: resolves each group to the tuple the
// model judges MOST relevant. (Stray `;` after the function body removed.)
template <>
void LlmMinOrMaxOperation::Finalize<MinOrMax::MAX>(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
                                                   idx_t offset) {
    FinalizeResults(states, aggr_input_data, result, count, offset, GetMinOrMaxPromptTemplate<MinOrMax::MAX>());
}

// Update path for the single-state (no GROUP BY) case: every row feeds the
// same aggregate state.
void LlmMinOrMaxOperation::SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
                                        data_ptr_t state_p, idx_t count) {
    prompt_name = inputs[0].GetValue(0).ToString();
    model_name = inputs[1].GetValue(0).ToString();
    auto tuples = CastVectorOfStructsToJson(inputs[2], count);

    // Renamed from `state_map_p` — it is a state pointer, not a map.
    auto state_ptr = reinterpret_cast<LlmMinOrMaxState *>(state_p);

    // The state is the same for all rows; resolve it through state_map once
    // instead of once per row. Assumes Initialize has run for this state.
    auto state = state_map[state_ptr];
    for (idx_t i = 0; i < count; i++) {
        state->Update(tuples[i]);
    }
}

// NULL inputs are skipped by the aggregate framework rather than passed to
// Operation/SimpleUpdate.
bool LlmMinOrMaxOperation::IgnoreNull() {
    return true;
}

} // namespace core
} // namespace flockmtl
14 changes: 14 additions & 0 deletions src/core/functions/prompt_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
namespace flockmtl {
namespace core {

// Converts the first `size` rows of a STRUCT vector into JSON objects, one
// per row, mapping each struct field name to the field value's string form.
std::vector<nlohmann::json> CastVectorOfStructsToJson(Vector &struct_vector, int size) {
    vector<nlohmann::json> vector_json;
    vector_json.reserve(size);

    // The struct layout is a property of the vector, not of a row: hoist the
    // type, child count, and key names out of the per-row loop.
    const auto &struct_type = struct_vector.GetType();
    auto child_count = StructType::GetChildCount(struct_type);

    for (int i = 0; i < size; i++) {
        nlohmann::json row;
        // GetChildren materializes every field of the row; the original
        // called it once per field — fetch it once per row instead.
        auto children = StructValue::GetChildren(struct_vector.GetValue(i));
        for (idx_t j = 0; j < child_count; j++) {
            row[StructType::GetChildName(struct_type, j)] = children[j].ToString();
        }
        vector_json.push_back(std::move(row));
    }
    return vector_json;
}

nlohmann::json GetMaxLengthValues(const std::vector<nlohmann::json> &params) {
nlohmann::json attr_to_max_token_length;

Expand Down
8 changes: 6 additions & 2 deletions src/core/model_manager/model_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ nlohmann::json ModelManager::CallComplete(const std::string &prompt, const std::
}

// Make a request to the OpenAI API
auto completion = openai::chat().create(request_payload);

nlohmann::json completion;
try {
completion = openai::chat().create(request_payload);
} catch (const std::exception &e) {
throw std::runtime_error("Error in making request to OpenAI API: " + std::string(e.what()));
}
// Check if the conversation was too long for the context window
if (completion["choices"][0]["finish_reason"] == "length") {
// Handle the error when the context window is too long
Expand Down
16 changes: 16 additions & 0 deletions src/core/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@

#include "flockmtl/common.hpp"
#include "flockmtl/core/config/config.hpp"
#include "flockmtl/core/functions/aggregate.hpp"
#include "flockmtl/core/functions/scalar.hpp"

namespace flockmtl {
namespace core {

// Returns the process-wide DuckDB connection, creating it on first call.
// The first caller must pass a non-null DatabaseInstance; subsequent callers
// may pass nullptr. The Connection is deliberately never freed (an immortal
// singleton that lives until process exit).
// NOTE(review): the lazy init is not synchronized — concurrent first calls
// could race; confirm Register() is always the first caller.
Connection &CoreModule::GetConnection(DatabaseInstance *db) {
    static Connection *con = nullptr;

    if (!con) {
        if (!db) {
            throw std::runtime_error("Connection is not initialized. Provide DatabaseInstance on first call.");
        }
        con = new Connection(*db);
    }

    return *con;
}

// Extension entry point: initializes the shared connection first (aggregate
// finalizers query it later), then registers scalar and aggregate functions
// and applies configuration.
void CoreModule::Register(DatabaseInstance &db) {
    CoreModule::GetConnection(&db);
    CoreScalarFunctions::Register(db);
    CoreAggregateFunctions::Register(db);
    Config::Configure(db);
}

Expand Down
21 changes: 21 additions & 0 deletions src/include/flockmtl/core/functions/aggregate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once
#include "flockmtl/common.hpp"

namespace flockmtl {

namespace core {

// Registration hub for the LLM-backed aggregate functions (llm_max/llm_min).
struct CoreAggregateFunctions {
    // Registers every aggregate function this extension provides.
    static void Register(DatabaseInstance &db) {
        RegisterLlmMaxFunction(db);
        RegisterLlmMinFunction(db);
    }

private:
    // Implemented in src/core/functions/aggregate/llm_max.cpp / llm_min.cpp.
    static void RegisterLlmMaxFunction(DatabaseInstance &db);
    static void RegisterLlmMinFunction(DatabaseInstance &db);
};

} // namespace core

} // namespace flockmtl
Loading

0 comments on commit d12884d

Please sign in to comment.