Skip to content

Commit

Permalink
improve settings
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Oct 22, 2024
1 parent 07243b8 commit 6c48a87
Showing 1 changed file with 85 additions and 42 deletions.
127 changes: 85 additions & 42 deletions src/open_prompt_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

#include <string>
#include <sstream>
#include <mutex>
#include <iostream>


namespace duckdb {

Expand Down Expand Up @@ -88,49 +91,75 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std


// Open Prompt
static std::string api_url = "http://localhost:11434/v1/chat/completions";
static std::string api_token = ""; // Store your API token here
static std::string model_name = "llama2"; // Default model
// Global settings
static std::string api_url = "http://localhost:11434/v1/chat/completions";
static std::string api_token = ""; // Store your API token here
static std::string model_name = "llama2"; // Default model
static std::mutex settings_mutex; // Mutex for protecting global settings

// Retrieve the API URL from the stored settings
static std::string GetApiUrl() {
return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url;
}
// Function to set API token
void SetApiToken(const std::string &token) {
std::lock_guard<std::mutex> guard(settings_mutex);
if (token.empty()) {
throw std::invalid_argument("API token cannot be empty.");
}
api_token = token;
std::cerr << "API token set to: " << api_token << std::endl; // Debugging output
}

// Retrieve the API token from the stored settings
static std::string GetApiToken() {
return api_token;
}
// Function to set API URL
void SetApiUrl(const std::string &url) {
std::lock_guard<std::mutex> guard(settings_mutex);
if (url.empty()) {
throw std::invalid_argument("URL cannot be empty.");
}
api_url = url;
std::cerr << "API URL set to: " << api_url << std::endl; // Debugging output
}

// Retrieve the model name from the stored settings
static std::string GetModelName() {
return model_name.empty() ? "llama2" : model_name;
}
// Function to set model name
void SetModelName(const std::string &model) {
std::lock_guard<std::mutex> guard(settings_mutex);
if (model.empty()) {
throw std::invalid_argument("Model name cannot be empty.");
}
model_name = model;
std::cerr << "Model name set to: " << model_name << std::endl; // Debugging output
}

// Function to set API token
void SetApiToken(const std::string &token) {
api_token = token;
}
// Retrieve the API URL from the stored settings
static std::string GetApiUrl() {
std::lock_guard<std::mutex> guard(settings_mutex);
return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url;
}

// Function to set API URL
void SetApiUrl(const std::string &url) {
api_url = url;
}
// Retrieve the API token from the stored settings
static std::string GetApiToken() {
std::lock_guard<std::mutex> guard(settings_mutex);
return api_token;
}

// Function to set model name
void SetModelName(const std::string &model) {
model_name = model;
}
// Retrieve the model name from the stored settings
static std::string GetModelName() {
std::lock_guard<std::mutex> guard(settings_mutex);
return model_name.empty() ? "llama2" : model_name;
}

// Open Prompt Function
static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() == 1); // Expecting only the prompt string
D_ASSERT(args.data.size() == 2); // Expecting the prompt and model name

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(),
[&](string_t user_prompt) {
std::string api_url = GetApiUrl(); // Retrieve the API URL from settings
std::string api_token = GetApiToken(); // Retrieve the API Token from settings
std::string model_name = GetModelName(); // Retrieve the model name from settings
std::string model_name;

if (!args.data[1].GetValue(0).IsNull()) {
model_name = args.data[1].GetValue(0).ToString(); // Use passed model name
} else {
model_name = GetModelName(); // Use the default model if none is provided
}

// Manually construct the JSON body as a string
std::string request_body = "{";
Expand Down Expand Up @@ -182,35 +211,49 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
});
}


static void LoadInternal(DatabaseInstance &instance) {
// Register open_prompt function
// Register open_prompt function with two arguments: prompt and model
ScalarFunctionSet open_prompt("open_prompt");
open_prompt.AddFunction(ScalarFunction(
{LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction));
ExtensionUtil::RegisterFunction(instance, open_prompt);

// Function to set API token
// Other set_* functions remain the same as before
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, // Change here
"set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
SetApiToken(args.data[0].GetValue(0).ToString());
return StringVector::AddString(result, "API token set successfully.");
try {
auto token = args.data[0].GetValue(0).ToString();
SetApiToken(token);
return StringVector::AddString(result, "API token set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set API token: " + std::string(e.what()));
}
}));

// Function to set API URL
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, // Change here
"set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
SetApiUrl(args.data[0].GetValue(0).ToString());
return StringVector::AddString(result, "API URL set successfully.");
try {
auto new_url = args.data[0].GetValue(0).ToString();
SetApiUrl(new_url);
return StringVector::AddString(result, "API URL set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set API URL: " + std::string(e.what()));
}
}));

// Function to set model name
ExtensionUtil::RegisterFunction(instance, ScalarFunction(
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, // Change here
"set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR,
[](DataChunk &args, ExpressionState &state, Vector &result) {
SetModelName(args.data[0].GetValue(0).ToString());
return StringVector::AddString(result, "Model name set successfully.");
try {
auto model = args.data[0].GetValue(0).ToString();
SetModelName(model);
return StringVector::AddString(result, "Model name set successfully.");
} catch (std::exception &e) {
return StringVector::AddString(result, "Failed to set model name: " + std::string(e.what()));
}
}));
}

Expand Down

0 comments on commit 6c48a87

Please sign in to comment.