remove redundant code (#41)

dorbanianas authored Oct 26, 2024
1 parent 1bd8ac0 commit f335913

Showing 6 changed files with 121 additions and 273 deletions.
2 changes: 1 addition & 1 deletion src/core/functions/CMakeLists.txt
@@ -1,5 +1,5 @@
add_subdirectory(scalar)

set(EXTENSION_SOURCES
${EXTENSION_SOURCES}
${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/prompt_builder.cpp
PARENT_SCOPE)
91 changes: 91 additions & 0 deletions src/core/functions/prompt_builder.cpp
@@ -0,0 +1,91 @@
#include <inja/inja.hpp>
#include <flockmtl/core/functions/prompt_builder.hpp>

namespace flockmtl {
namespace core {

nlohmann::json GetMaxLengthValues(const std::vector<nlohmann::json> &params) {
nlohmann::json attr_to_max_token_length;

for (const auto &json_obj : params) {
for (const auto &item : json_obj.items()) {
auto key = item.key();
auto value_str = static_cast<string>(item.value());
int length = value_str.length();

if (attr_to_max_token_length.contains(key)) {
auto current_max_value_str = static_cast<string>(attr_to_max_token_length[key]);
if (current_max_value_str.length() < length) {
attr_to_max_token_length[key] = item.value();
}
} else {
attr_to_max_token_length[key] = item.value();
}
}
}

return attr_to_max_token_length;
}

std::string CombineValues(const nlohmann::json &json_obj) {
std::string combined;
for (const auto &item : json_obj.items()) {
combined += static_cast<string>(item.value()) + " ";
}

if (!combined.empty()) {
combined.pop_back();
}
return combined;
}

std::vector<std::string> ConstructPrompts(std::vector<nlohmann::json> &unique_rows, Connection &con,
std::string user_prompt_name, const std::string &llm_template,
int model_max_tokens) {
inja::Environment env;

auto query_result =
con.Query("SELECT prompt FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE WHERE prompt_name = '" +
user_prompt_name + "'");

if (query_result->RowCount() == 0) {
throw std::runtime_error("Prompt not found");
}

auto user_prompt = query_result->GetValue(0, 0).ToString();
int row_tokens = 0;
row_tokens += Tiktoken::GetNumTokens(user_prompt);
auto max_length_values = GetMaxLengthValues(unique_rows);
auto combined_values = CombineValues(max_length_values);
row_tokens += Tiktoken::GetNumTokens(combined_values);

std::vector<std::string> prompts;

if (row_tokens > model_max_tokens) {
throw std::runtime_error("The total number of tokens in the prompt exceeds the model's maximum token limit");
} else {
auto template_tokens = Tiktoken::GetNumTokens(llm_template);
auto max_tokens_for_rows = model_max_tokens - template_tokens;
auto max_chunk_size = max_tokens_for_rows / row_tokens;
auto chunk_size = max_chunk_size > static_cast<int>(unique_rows.size()) ? static_cast<int>(unique_rows.size())
: max_chunk_size;
auto num_chunks = static_cast<int>(std::ceil(static_cast<double>(unique_rows.size()) / chunk_size));

for (int i = 0; i < num_chunks; ++i) {
nlohmann::json data;
data["prompts"] = user_prompt;

for (int j = 0; j < chunk_size; ++j) {
data["rows"].push_back(unique_rows[i + j]);
}

auto prompt = env.render(llm_template, data);
prompts.push_back(prompt);
}
}

return prompts;
}

} // namespace core
} // namespace flockmtl
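
For intuition about the chunking arithmetic in ConstructPrompts above, here is a small standalone sketch that reproduces the token-budget calculation with made-up token counts (the numbers are illustrative, not measured by this commit):

#include <cmath>
#include <cstdio>

int main() {
    // Illustrative numbers only.
    int model_max_tokens = 4096; // context window of the target model
    int template_tokens = 200;   // tokens consumed by the inja template itself
    int row_tokens = 120;        // user prompt + worst-case row (GetMaxLengthValues/CombineValues estimate)
    int num_rows = 100;          // number of unique rows to process

    int max_tokens_for_rows = model_max_tokens - template_tokens;           // 3896
    int max_chunk_size = max_tokens_for_rows / row_tokens;                  // 32 rows fit alongside the template
    int chunk_size = max_chunk_size > num_rows ? num_rows : max_chunk_size; // 32
    int num_chunks = static_cast<int>(std::ceil(static_cast<double>(num_rows) / chunk_size)); // 4 prompts

    std::printf("chunk_size=%d, num_chunks=%d\n", chunk_size, num_chunks);
    return 0;
}

With these numbers the rows would be rendered into 4 prompts of at most 32 rows each, every prompt staying under the 4096-token budget.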
95 changes: 3 additions & 92 deletions src/core/functions/scalar/llm_complete.cpp
@@ -1,14 +1,11 @@
#include <functional>
#include <inja/inja.hpp>
#include <iostream>
#include <flockmtl/core/functions/prompt_builder.hpp>
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/scalar.hpp>
#include <flockmtl/core/model_manager/model_manager.hpp>
#include <flockmtl/core/model_manager/openai.hpp>
#include <flockmtl/core/model_manager/tiktoken.hpp>
#include <flockmtl/core/parser/llm_response.hpp>
#include <flockmtl/core/parser/scalar.hpp>
#include <flockmtl_extension.hpp>
#include <nlohmann/json.hpp>
#include <sstream>
#include <string>
@@ -17,93 +14,6 @@
namespace flockmtl {
namespace core {

template <typename T>
std::string ToString3(const T &value) {
std::ostringstream oss;
oss << value;
return oss.str();
}

nlohmann::json GetMaxLengthValues3(const std::vector<nlohmann::json> &params) {
nlohmann::json attr_to_max_token_length;

for (const auto &json_obj : params) {
for (const auto &item : json_obj.items()) {
auto key = item.key();
auto value_str = ToString3(item.value());
int length = value_str.length();

if (attr_to_max_token_length.contains(key)) {
auto current_max_value_str = ToString3(attr_to_max_token_length[key]);
if (current_max_value_str.length() < length) {
attr_to_max_token_length[key] = item.value();
}
} else {
attr_to_max_token_length[key] = item.value();
}
}
}

return attr_to_max_token_length;
}

std::string combine_values3(const nlohmann::json &json_obj) {
std::string combined;
for (const auto &item : json_obj.items()) {
combined += ToString3(item.value()) + " ";
}

if (!combined.empty()) {
combined.pop_back();
}
return combined;
}

inline std::vector<std::string> ConstructPrompts3(std::vector<nlohmann::json> &unique_rows, Connection &con,
std::string prompt_name, int model_max_tokens = 4096) {
inja::Environment env;

auto query_result = con.Query(
"SELECT prompt FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE WHERE prompt_name = '" + prompt_name + "'");

if (query_result->RowCount() == 0) {
throw std::runtime_error("Prompt not found");
}

auto template_str = query_result->GetValue(0, 0).ToString();
auto row_tokens = Tiktoken::GetNumTokens(template_str);
auto max_length_values = GetMaxLengthValues3(unique_rows);
auto combined_values = combine_values3(max_length_values);
row_tokens += Tiktoken::GetNumTokens(combined_values);

std::vector<std::string> prompts;

if (row_tokens > model_max_tokens) {
throw std::runtime_error("The total number of tokens in the prompt exceeds the model's maximum token limit");
} else {
auto template_tokens = Tiktoken::GetNumTokens(llm_complete_prompt_template);
auto max_tokens_for_rows = model_max_tokens - template_tokens;
auto max_chunk_size = max_tokens_for_rows / row_tokens;
auto chunk_size = max_chunk_size > static_cast<int>(unique_rows.size()) ? static_cast<int>(unique_rows.size())
: max_chunk_size;
auto num_chunks = static_cast<int>(std::ceil(static_cast<double>(unique_rows.size()) / chunk_size));

for (int i = 0; i < num_chunks; ++i) {
nlohmann::json data;
data["prompts"] = template_str;

for (int j = 0; j < chunk_size; ++j) {
data["rows"].push_back(unique_rows[i + j]);
}

auto prompt = env.render(llm_complete_prompt_template, data);
prompts.push_back(prompt);
}
}

return prompts;
}

static void LlmCompleteScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
Connection con(*state.GetContext().db);
CoreScalarParsers::LlmCompleteScalarParser(args);
@@ -137,7 +47,8 @@ static void LlmCompleteScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
} else {
auto tuples = CoreScalarParsers::Struct2Json(args.data[2], args.size());

auto prompts = ConstructPrompts3(tuples, con, args.data[0].GetValue(0).ToString(), model_max_tokens);
auto prompts = ConstructPrompts(tuples, con, args.data[0].GetValue(0).ToString(), llm_complete_prompt_template,
model_max_tokens);
nlohmann::json settings;
if (args.ColumnCount() == 4) {
settings = CoreScalarParsers::Struct2Json(args.data[3], 1)[0];
92 changes: 3 additions & 89 deletions src/core/functions/scalar/llm_complete_json.cpp
@@ -1,6 +1,6 @@
#include <functional>
#include <inja/inja.hpp>
#include <iostream>
#include <flockmtl/core/functions/prompt_builder.hpp>
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/scalar.hpp>
#include <flockmtl/core/model_manager/model_manager.hpp>
@@ -17,93 +17,6 @@
namespace flockmtl {
namespace core {

template <typename T>
std::string ToString(const T &value) {
std::ostringstream oss;
oss << value;
return oss.str();
}

nlohmann::json GetMaxLengthValues(const std::vector<nlohmann::json> &params) {
nlohmann::json attr_to_max_token_length;

for (const auto &json_obj : params) {
for (const auto &item : json_obj.items()) {
auto key = item.key();
auto value_str = ToString(item.value());
int length = value_str.length();

if (attr_to_max_token_length.contains(key)) {
auto current_max_value_str = ToString(attr_to_max_token_length[key]);
if (current_max_value_str.length() < length) {
attr_to_max_token_length[key] = item.value();
}
} else {
attr_to_max_token_length[key] = item.value();
}
}
}

return attr_to_max_token_length;
}

std::string combine_values(const nlohmann::json &json_obj) {
std::string combined;
for (const auto &item : json_obj.items()) {
combined += ToString(item.value()) + " ";
}

if (!combined.empty()) {
combined.pop_back();
}
return combined;
}

inline std::vector<std::string> ConstructPrompts(std::vector<nlohmann::json> &unique_rows, Connection &con,
std::string prompt_name, int model_max_tokens = 4096) {
inja::Environment env;

auto query_result = con.Query(
"SELECT prompt FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE WHERE prompt_name = '" + prompt_name + "'");

if (query_result->RowCount() == 0) {
throw std::runtime_error("Prompt not found");
}

auto template_str = query_result->GetValue(0, 0).ToString();
auto row_tokens = Tiktoken::GetNumTokens(template_str);
auto max_length_values = GetMaxLengthValues(unique_rows);
auto combined_values = combine_values(max_length_values);
row_tokens += Tiktoken::GetNumTokens(combined_values);

std::vector<std::string> prompts;

if (row_tokens > model_max_tokens) {
throw std::runtime_error("The total number of tokens in the prompt exceeds the model's maximum token limit");
} else {
auto template_tokens = Tiktoken::GetNumTokens(llm_complete_json_prompt_template);
auto max_tokens_for_rows = model_max_tokens - template_tokens;
auto max_chunk_size = max_tokens_for_rows / row_tokens;
auto chunk_size = max_chunk_size > static_cast<int>(unique_rows.size()) ? static_cast<int>(unique_rows.size())
: max_chunk_size;
auto num_chunks = static_cast<int>(std::ceil(static_cast<double>(unique_rows.size()) / chunk_size));

for (int i = 0; i < num_chunks; ++i) {
nlohmann::json data;
data["prompts"] = template_str;

for (int j = 0; j < chunk_size; ++j) {
data["rows"].push_back(unique_rows[i + j]);
}

auto prompt = env.render(llm_complete_json_prompt_template, data);
prompts.push_back(prompt);
}
}

return prompts;
}

static void LlmCompleteJsonScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
Connection con(*state.GetContext().db);
CoreScalarParsers::LlmCompleteJsonScalarParser(args);
@@ -137,7 +50,8 @@ static void LlmCompleteJsonScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
} else {
auto tuples = CoreScalarParsers::Struct2Json(args.data[2], args.size());

auto prompts = ConstructPrompts(tuples, con, args.data[0].GetValue(0).ToString(), model_max_tokens);
auto prompts = ConstructPrompts(tuples, con, args.data[0].GetValue(0).ToString(),
llm_complete_json_prompt_template, model_max_tokens);

nlohmann::json settings;
if (args.ColumnCount() == 4) {
(Diffs for the remaining two changed files are not shown.)
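
One of the unshown files is presumably the new header flockmtl/core/functions/prompt_builder.hpp that the sources above now include. Inferred from the definitions in prompt_builder.cpp, it would declare roughly the following; this is a sketch based on that assumption, not the actual header contents:

#pragma once

#include <nlohmann/json.hpp>
#include <string>
#include <vector>

// NOTE: illustrative only -- the real header also needs the DuckDB Connection
// type (and, transitively, the Tiktoken helper used by prompt_builder.cpp),
// presumably via the project's common headers.

namespace flockmtl {
namespace core {

// Per attribute, keep the value whose string form is longest (used to
// estimate the worst-case token cost of a single row).
nlohmann::json GetMaxLengthValues(const std::vector<nlohmann::json> &params);

// Join all attribute values of a JSON object into one space-separated string.
std::string CombineValues(const nlohmann::json &json_obj);

// Split the rows into chunks that fit the model's context window and render
// one prompt per chunk from the given inja template.
std::vector<std::string> ConstructPrompts(std::vector<nlohmann::json> &unique_rows, Connection &con,
                                          std::string user_prompt_name, const std::string &llm_template,
                                          int model_max_tokens);

} // namespace core
} // namespace flockmtl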
