
change prompt_name for aggregate functions into search_query (#47)
dorbanianas authored Oct 30, 2024
1 parent ff128d6 commit 6d95fe9
Showing 6 changed files with 31 additions and 39 deletions.
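The practical effect of the rename, as suggested by the lookup removed from `RerankerFinalize` below, is that the aggregate's first argument is no longer treated as the name of a stored prompt to resolve in `flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE`; it is consumed directly as the search query text. A minimal sketch of that difference, using a hypothetical in-memory stand-in for the prompt table:

```cpp
// Sketch only: the map stands in for the FLOCKMTL_PROMPT_INTERNAL_TABLE lookup
// the old code performed; everything except the error message is illustrative.
#include <map>
#include <stdexcept>
#include <string>

static const std::map<std::string, std::string> stored_prompts = {
    {"product_search", "Find the most relevant product for the user."}};

// Old behavior: the first argument named a stored prompt.
std::string ResolveQueryBefore(const std::string &first_arg) {
    auto it = stored_prompts.find(first_arg);
    if (it == stored_prompts.end()) {
        throw std::runtime_error("Prompt not found");
    }
    return it->second;
}

// New behavior: the first argument is the search query itself.
std::string ResolveQueryAfter(const std::string &first_arg) {
    return first_arg;
}
```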
28 changes: 14 additions & 14 deletions src/core/functions/aggregate/llm_agg.cpp
@@ -18,32 +18,32 @@ void LlmAggState::Combine(const LlmAggState &source) {
     value = std::move(source.value);
 }

-LlmMinOrMax::LlmMinOrMax(std::string &model, int model_context_size, std::string &user_prompt,
+LlmMinOrMax::LlmMinOrMax(std::string &model, int model_context_size, std::string &search_query,
                          std::string &llm_min_or_max_template)
-    : model(model), model_context_size(model_context_size), user_prompt(user_prompt),
+    : model(model), model_context_size(model_context_size), search_query(search_query),
       llm_min_or_max_template(llm_min_or_max_template) {

-    auto num_tokens_meta_and_user_prompts = calculateFixedTokens();
+    auto num_tokens_meta_and_search_query = calculateFixedTokens();

-    if (num_tokens_meta_and_user_prompts > model_context_size) {
+    if (num_tokens_meta_and_search_query > model_context_size) {
         throw std::runtime_error("Fixed tokens exceed model context size");
     }

-    available_tokens = model_context_size - num_tokens_meta_and_user_prompts;
+    available_tokens = model_context_size - num_tokens_meta_and_search_query;
 }

 int LlmMinOrMax::calculateFixedTokens() const {
-    int num_tokens_meta_and_user_prompts = 0;
-    num_tokens_meta_and_user_prompts += Tiktoken::GetNumTokens(user_prompt);
-    num_tokens_meta_and_user_prompts += Tiktoken::GetNumTokens(llm_min_or_max_template);
-    return num_tokens_meta_and_user_prompts;
+    int num_tokens_meta_and_search_query = 0;
+    num_tokens_meta_and_search_query += Tiktoken::GetNumTokens(search_query);
+    num_tokens_meta_and_search_query += Tiktoken::GetNumTokens(llm_min_or_max_template);
+    return num_tokens_meta_and_search_query;
 }

 nlohmann::json LlmMinOrMax::GetMaxOrMinTupleId(const nlohmann::json &tuples) {
     inja::Environment env;
     nlohmann::json data;
     data["tuples"] = tuples;
-    data["user_prompt"] = user_prompt;
+    data["search_query"] = search_query;
     auto prompt = env.render(llm_min_or_max_template, data);

     nlohmann::json settings;
@@ -96,7 +96,7 @@ nlohmann::json LlmMinOrMax::Evaluate(nlohmann::json &tuples) {

 // Static member initialization
 std::string LlmAggOperation::model_name;
-std::string LlmAggOperation::prompt_name;
+std::string LlmAggOperation::search_query;
 std::unordered_map<void *, std::shared_ptr<LlmAggState>> LlmAggOperation::state_map;


@@ -112,7 +112,7 @@ void LlmAggOperation::Initialize(const AggregateFunction &, data_ptr_t state_p)

 void LlmAggOperation::Operation(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
                                 idx_t count) {
-    prompt_name = inputs[0].GetValue(0).ToString();
+    search_query = inputs[0].GetValue(0).ToString();
     model_name = inputs[1].GetValue(0).ToString();

     if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
@@ -174,7 +174,7 @@ void LlmAggOperation::FinalizeResults(Vector &states, AggregateInputData &aggr_i
         tuples_with_ids.push_back(tuple_with_id);
     }

-    LlmMinOrMax llm_min_or_max(model, model_context_size, prompt_name, llm_prompt_template);
+    LlmMinOrMax llm_min_or_max(model, model_context_size, search_query, llm_prompt_template);
     auto response = llm_min_or_max.Evaluate(tuples_with_ids);
     result.SetValue(idx, response.dump());
 }
@@ -194,7 +194,7 @@ void LlmAggOperation::MinOrMaxFinalize<MinOrMax::MAX>(Vector &states, AggregateI

 void LlmAggOperation::SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
                                    data_ptr_t state_p, idx_t count) {
-    prompt_name = inputs[0].GetValue(0).ToString();
+    search_query = inputs[0].GetValue(0).ToString();
     model_name = inputs[1].GetValue(0).ToString();
     auto tuples = CastVectorOfStructsToJson(inputs[2], count);

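The constructors above budget tokens before any tuples are serialized: the search query and the prompt template count as fixed overhead, and whatever remains of the model context is what the tuples may consume. A standalone sketch of that check, with an approximate tokenizer standing in for Tiktoken::GetNumTokens:

```cpp
// Sketch of the fixed-token budget check; CountTokensApprox is a rough stand-in
// (roughly four characters per token) for the real Tiktoken::GetNumTokens.
#include <stdexcept>
#include <string>

static int CountTokensApprox(const std::string &text) {
    return static_cast<int>(text.size() / 4) + 1;
}

int AvailableTokens(const std::string &search_query, const std::string &prompt_template,
                    int model_context_size) {
    int fixed = CountTokensApprox(search_query) + CountTokensApprox(prompt_template);
    if (fixed > model_context_size) {
        throw std::runtime_error("Fixed tokens exceed model context size");
    }
    return model_context_size - fixed; // tokens left for the serialized tuples
}
```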
30 changes: 11 additions & 19 deletions src/core/functions/aggregate/llm_rerank.cpp
@@ -15,16 +15,16 @@
 namespace flockmtl {
 namespace core {

-LlmReranker::LlmReranker(std::string& model, int model_context_size, std::string& user_prompt, std::string& llm_reranking_template)
-    : model(model), model_context_size(model_context_size), user_prompt(user_prompt), llm_reranking_template(llm_reranking_template) {
+LlmReranker::LlmReranker(std::string& model, int model_context_size, std::string& search_query, std::string& llm_reranking_template)
+    : model(model), model_context_size(model_context_size), search_query(search_query), llm_reranking_template(llm_reranking_template) {

-    auto num_tokens_meta_and_user_prompts = CalculateFixedTokens();
+    auto num_tokens_meta_and_search_query = CalculateFixedTokens();

-    if (num_tokens_meta_and_user_prompts > model_context_size) {
+    if (num_tokens_meta_and_search_query > model_context_size) {
         throw std::runtime_error("Fixed tokens exceed model context size");
     }

-    available_tokens = model_context_size - num_tokens_meta_and_user_prompts;
+    available_tokens = model_context_size - num_tokens_meta_and_search_query;
 };

 nlohmann::json LlmReranker::SlidingWindowRerank(nlohmann::json& tuples) {
@@ -77,17 +77,17 @@ nlohmann::json LlmReranker::SlidingWindowRerank(nlohmann::json& tuples) {
 }

 int LlmReranker::CalculateFixedTokens() const {
-    int num_tokens_meta_and_user_prompts = 0;
-    num_tokens_meta_and_user_prompts += Tiktoken::GetNumTokens(user_prompt);
-    num_tokens_meta_and_user_prompts += Tiktoken::GetNumTokens(llm_reranking_template);
-    return num_tokens_meta_and_user_prompts;
+    int num_tokens_meta_and_search_query = 0;
+    num_tokens_meta_and_search_query += Tiktoken::GetNumTokens(search_query);
+    num_tokens_meta_and_search_query += Tiktoken::GetNumTokens(llm_reranking_template);
+    return num_tokens_meta_and_search_query;
 }

 nlohmann::json LlmReranker::LlmRerankWithSlidingWindow(const nlohmann::json& tuples) {
     inja::Environment env;
     nlohmann::json data;
     data["tuples"] = tuples;
-    data["user_prompt"] = user_prompt;
+    data["search_query"] = search_query;
     auto prompt = env.render(llm_reranking_template, data);

     nlohmann::json settings;
@@ -115,14 +115,6 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_
     auto model = query_result->GetValue(0, 0).ToString();
     auto model_context_size = query_result->GetValue(1, 0).GetValue<int>();

-    query_result = CoreModule::GetConnection().Query("SELECT prompt FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE WHERE prompt_name = '" +
-                                                     prompt_name + "'");
-
-    if (query_result->RowCount() == 0) {
-        throw std::runtime_error("Prompt not found");
-    }
-
-    auto user_prompt = query_result->GetValue(0, 0).ToString();
     auto llm_rerank_prompt_template_str = std::string(llm_rerank_prompt_template);

     auto tuples_with_ids = nlohmann::json::array();
@@ -133,7 +125,7 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_
         tuples_with_ids.push_back(tuple_with_id);
     }

-    LlmReranker llm_reranker(model, model_context_size, user_prompt, llm_rerank_prompt_template_str);
+    LlmReranker llm_reranker(model, model_context_size, search_query, llm_rerank_prompt_template_str);
     auto reranked_tuples = llm_reranker.SlidingWindowRerank(tuples_with_ids);

     result.SetValue(idx, reranked_tuples.dump());
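Both GetMaxOrMinTupleId and LlmRerankWithSlidingWindow fill their prompt the same way: the tuples and the renamed search_query key go into a JSON object that is rendered with inja. A self-contained sketch of that rendering step, with abbreviated template text and a hypothetical query:

```cpp
// Sketch of the inja rendering step; the template string is shortened and the
// query and tuple values are hypothetical, but the placeholder names match the code above.
#include <inja/inja.hpp>
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    inja::Environment env;

    nlohmann::json tuples = nlohmann::json::array();
    tuples.push_back({{"id", 1}, {"content", "27-inch 4K monitor, 60 Hz"}});
    tuples.push_back({{"id", 2}, {"content", "Mechanical keyboard, RGB backlight"}});

    nlohmann::json data;
    data["tuples"] = tuples;
    data["search_query"] = "best display for photo editing";

    std::string tmpl = "{% for tuple in tuples %}[{{ tuple.id }}] {{ tuple.content }}\n{% endfor %}"
                       "Search Query: {{ search_query }}\n";
    std::cout << env.render(tmpl, data); // prompt text sent to the model
}
```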
6 changes: 3 additions & 3 deletions src/include/flockmtl/core/functions/aggregate/llm_agg.hpp
@@ -22,11 +22,11 @@ class LlmMinOrMax {
 public:
     std::string model;
     int model_context_size;
-    std::string user_prompt;
+    std::string search_query;
     std::string llm_min_or_max_template;
     int available_tokens;

-    LlmMinOrMax(std::string &model, int model_context_size, std::string &user_prompt,
+    LlmMinOrMax(std::string &model, int model_context_size, std::string &search_query,
                 std::string &llm_min_or_max_template);

     nlohmann::json GetMaxOrMinTupleId(const nlohmann::json &tuples);
@@ -38,7 +38,7 @@

 struct LlmAggOperation {
     static std::string model_name;
-    static std::string prompt_name;
+    static std::string search_query;
     static std::unordered_map<void *, std::shared_ptr<LlmAggState>> state_map;

     static void Initialize(const AggregateFunction &, data_ptr_t state_p);
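With the header change, callers construct LlmMinOrMax from a query string rather than a prompt name. A usage sketch under the new signature; the values are placeholders, and it assumes the header above is on the include path and that the class lives in flockmtl::core as the .cpp files suggest:

```cpp
// Usage sketch only: model, context size, query, and template values are
// placeholders, not read from the extension's configuration tables.
#include "flockmtl/core/functions/aggregate/llm_agg.hpp"
#include <string>

void BuildSelector() {
    std::string model = "gpt-4o-mini";
    int model_context_size = 128000;
    std::string search_query = "most affordable 4K monitor";
    std::string prompt_template = "Search Query: {{search_query}}";
    flockmtl::core::LlmMinOrMax selector(model, model_context_size, search_query, prompt_template);
    // selector.Evaluate(...) would then be called with the serialized tuples.
}
```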
@@ -14,7 +14,7 @@ class LlmReranker {
 private:
     std::string model;
     int model_context_size;
-    std::string user_prompt;
+    std::string search_query;
     std::string llm_reranking_template;
     int available_tokens;

2 changes: 1 addition & 1 deletion src/include/templates/llm_min_or_max_prompt_template.hpp
@@ -14,7 +14,7 @@ I will provide you with some tuples, each indicated by a numerical identifier []
 [{{tuple.id}}] {{tuple.content}}
 {% endfor %}
-Search Query: {{user_prompt}}
+Search Query: {{search_query}}
 Select the **single {relevance} relevant tuple** from the tuples provided.The output format should be a JSON object. Only respond with the result, do not explain or add additional words.
2 changes: 1 addition & 1 deletion src/include/templates/llm_rerank_prompt_template.hpp
@@ -12,7 +12,7 @@ I will provide you with some tuples, each indicated by a numerical identifier []
 [{{tuple.id}}] {{tuple.content}}
 {% endfor %}
-Search Query: {{user_prompt}}
+Search Query: {{search_query}}
 Rank the tuples above based on their relevance to the search query. All the passages should be included and listed using identifiers, in descending order of relevance. The output format should be in JSON list [id_1, ..., id_n], e.g., [22, 33, ..., 3], Only respond with the ranking results, do not say any word or explain.
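The rerank template above asks the model to return only a JSON list of tuple identifiers in descending order of relevance (e.g., [22, 33, ..., 3]). A small sketch of consuming such a reply with nlohmann::json, using a hypothetical model response:

```cpp
// Sketch of parsing the ranking format requested by the template; the response
// string is a hypothetical model reply, not output captured from the extension.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>
#include <vector>

int main() {
    std::string response = "[22, 33, 3]";
    auto ids = nlohmann::json::parse(response).get<std::vector<int>>();
    for (int id : ids) {
        std::cout << id << '\n'; // tuple ids, most relevant first
    }
}
```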
