Skip to content

Commit

Permalink
refactor, deduplicate
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani committed Oct 20, 2024
1 parent dd65a86 commit abcde70
Showing 1 changed file with 98 additions and 146 deletions.
244 changes: 98 additions & 146 deletions src/http_client_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,141 +15,133 @@

namespace duckdb {

// Helper function to parse URL and setup client
static std::pair<duckdb_httplib_openssl::Client, std::string> SetupHttpClient(const std::string &url) {
std::string scheme, domain, path;
size_t pos = url.find("://");
std::string mod_url = url;
if (pos != std::string::npos) {
scheme = mod_url.substr(0, pos);
mod_url.erase(0, pos + 3);
}

pos = mod_url.find("/");
if (pos != std::string::npos) {
domain = mod_url.substr(0, pos);
path = mod_url.substr(pos);
} else {
domain = mod_url;
path = "/";
}

// Create client and set a reasonable timeout (e.g., 10 seconds)
duckdb_httplib_openssl::Client client(domain.c_str());
client.set_read_timeout(10, 0); // 10 seconds
client.set_follow_location(true); // Follow redirects

return std::make_pair(std::move(client), path);
}

static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std::string &request_type) {
std::string err_message = "HTTP " + request_type + " request failed. ";

switch (res.error()) {
case duckdb_httplib_openssl::Error::Connection:
err_message += "Connection error.";
break;
case duckdb_httplib_openssl::Error::BindIPAddress:
err_message += "Failed to bind IP address.";
break;
case duckdb_httplib_openssl::Error::Read:
err_message += "Error reading response.";
break;
case duckdb_httplib_openssl::Error::Write:
err_message += "Error writing request.";
break;
case duckdb_httplib_openssl::Error::ExceedRedirectCount:
err_message += "Too many redirects.";
break;
case duckdb_httplib_openssl::Error::Canceled:
err_message += "Request was canceled.";
break;
case duckdb_httplib_openssl::Error::SSLConnection:
err_message += "SSL connection failed.";
break;
case duckdb_httplib_openssl::Error::SSLLoadingCerts:
err_message += "Failed to load SSL certificates.";
break;
case duckdb_httplib_openssl::Error::SSLServerVerification:
err_message += "SSL server verification failed.";
break;
case duckdb_httplib_openssl::Error::UnsupportedMultipartBoundaryChars:
err_message += "Unsupported characters in multipart boundary.";
break;
case duckdb_httplib_openssl::Error::Compression:
err_message += "Error during compression.";
break;
default:
err_message += "Unknown error.";
break;
}
throw std::runtime_error(err_message);
}


static void HTTPGetRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() == 1);

UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(), [&](string_t input) {
std::string url = input.GetString();

// Parse the URL to extract the domain and path
std::string scheme, domain, path;
size_t pos = url.find("://");
if (pos != std::string::npos) {
scheme = url.substr(0, pos);
url.erase(0, pos + 3);
}

pos = url.find("/");
if (pos != std::string::npos) {
domain = url.substr(0, pos);
path = url.substr(pos);
} else {
domain = url;
path = "/";
}

// Create client and set a reasonable timeout (e.g., 10 seconds)
duckdb_httplib_openssl::Client client(domain.c_str());
client.set_read_timeout(10, 0); // 10 seconds

// Follow redirects
client.set_follow_location(true);
// Use helper to setup client and parse URL
auto client_and_path = SetupHttpClient(url);
auto &client = client_and_path.first;
auto &path = client_and_path.second;

// Make the GET request
auto res = client.Get(path.c_str());
if (res) {
if (res->status == 200) {
return StringVector::AddString(result, res->body);
} else {
throw std::runtime_error("HTTP error: " + std::to_string(res->status) + " - " + res->reason);
throw std::runtime_error("HTTP GET error: " + std::to_string(res->status) + " - " + res->reason);
}
} else {
// Handle the error case
std::string err_message = "HTTP request failed. ";

// Convert httplib error codes to a descriptive message
switch (res.error()) {
case duckdb_httplib_openssl::Error::Connection:
err_message += "Connection error.";
break;
case duckdb_httplib_openssl::Error::BindIPAddress:
err_message += "Failed to bind IP address.";
break;
case duckdb_httplib_openssl::Error::Read:
err_message += "Error reading response.";
break;
case duckdb_httplib_openssl::Error::Write:
err_message += "Error writing request.";
break;
case duckdb_httplib_openssl::Error::ExceedRedirectCount:
err_message += "Too many redirects.";
break;
case duckdb_httplib_openssl::Error::Canceled:
err_message += "Request was canceled.";
break;
case duckdb_httplib_openssl::Error::SSLConnection:
err_message += "SSL connection failed.";
break;
case duckdb_httplib_openssl::Error::SSLLoadingCerts:
err_message += "Failed to load SSL certificates.";
break;
case duckdb_httplib_openssl::Error::SSLServerVerification:
err_message += "SSL server verification failed.";
break;
case duckdb_httplib_openssl::Error::UnsupportedMultipartBoundaryChars:
err_message += "Unsupported characters in multipart boundary.";
break;
case duckdb_httplib_openssl::Error::Compression:
err_message += "Error during compression.";
break;
default:
err_message += "Unknown error.";
break;
}
throw std::runtime_error(err_message);
// Handle errors
HandleHttpError(res, "GET");
}
// Ensure a return value in case of an error
return string_t();
});
}

static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) {
D_ASSERT(args.data.size() == 3);

auto &url_vector = args.data[0];
auto &headers_vector = args.data[1]; // Already passed as a serialized string
auto &body_vector = args.data[2]; // Already passed as a JSON string
auto &headers_vector = args.data[1];
auto &body_vector = args.data[2];

// Use TernaryExecutor instead of UnaryExecutor
TernaryExecutor::Execute<string_t, string_t, string_t, string_t>(
url_vector, headers_vector, body_vector, result, args.size(),
[&](string_t url, string_t headers_varchar, string_t body_varchar) {
[&](string_t url, string_t headers, string_t body) {
std::string url_str = url.GetString();

// Parse the URL to extract the domain and path
std::string scheme, domain, path;
size_t pos = url_str.find("://");
if (pos != std::string::npos) {
scheme = url_str.substr(0, pos);
url_str.erase(0, pos + 3);
}

pos = url_str.find("/");
if (pos != std::string::npos) {
domain = url_str.substr(0, pos);
path = url_str.substr(pos);
} else {
domain = url_str;
path = "/";
}

// Create the client and set a timeout
duckdb_httplib_openssl::Client client(domain.c_str());
client.set_read_timeout(10, 0); // 10 seconds
// Follow redirects
client.set_follow_location(true);

// Follow redirects for POST as well
client.set_follow_location(true);
// Use helper to setup client and parse URL
auto client_and_path = SetupHttpClient(url_str);
auto &client = client_and_path.first;
auto &path = client_and_path.second;

// Deserialize the header string into a header map
// Prepare headers
duckdb_httplib_openssl::Headers header_map;
std::istringstream header_stream(headers_varchar.GetString());
std::istringstream header_stream(headers.GetString());
std::string header;
while (std::getline(header_stream, header)) {
size_t colon_pos = header.find(':');
if (colon_pos != std::string::npos) {
std::string key = header.substr(0, colon_pos);
std::string value = header.substr(colon_pos + 1);
// Trim leading/trailing whitespace
// Trim leading and trailing whitespace
key.erase(0, key.find_first_not_of(" \t"));
key.erase(key.find_last_not_of(" \t") + 1);
value.erase(0, value.find_first_not_of(" \t"));
Expand All @@ -158,60 +150,20 @@ static void HTTPPostRequestFunction(DataChunk &args, ExpressionState &state, Vec
}
}

// Prepare the POST body (it is passed as a string)
std::string body_str = body_varchar.GetString();

// Make the POST request
auto res = client.Post(path.c_str(), header_map, body_str, "application/json");
// Make the POST request with headers and body
auto res = client.Post(path.c_str(), header_map, body.GetString(), "application/json");
if (res) {
if (res->status == 200) {
return StringVector::AddString(result, res->body);
} else {
throw std::runtime_error("HTTP error: " + std::to_string(res->status) + " - " + res->reason);
throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason);
}
} else {
// Handle the error case
std::string err_message = "HTTP POST request failed. ";
switch (res.error()) {
case duckdb_httplib_openssl::Error::Connection:
err_message += "Connection error.";
break;
case duckdb_httplib_openssl::Error::BindIPAddress:
err_message += "Failed to bind IP address.";
break;
case duckdb_httplib_openssl::Error::Read:
err_message += "Error reading response.";
break;
case duckdb_httplib_openssl::Error::Write:
err_message += "Error writing request.";
break;
case duckdb_httplib_openssl::Error::ExceedRedirectCount:
err_message += "Too many redirects.";
break;
case duckdb_httplib_openssl::Error::Canceled:
err_message += "Request was canceled.";
break;
case duckdb_httplib_openssl::Error::SSLConnection:
err_message += "SSL connection failed.";
break;
case duckdb_httplib_openssl::Error::SSLLoadingCerts:
err_message += "Failed to load SSL certificates.";
break;
case duckdb_httplib_openssl::Error::SSLServerVerification:
err_message += "SSL server verification failed.";
break;
case duckdb_httplib_openssl::Error::UnsupportedMultipartBoundaryChars:
err_message += "Unsupported characters in multipart boundary.";
break;
case duckdb_httplib_openssl::Error::Compression:
err_message += "Error during compression.";
break;
default:
err_message += "Unknown error.";
break;
}
throw std::runtime_error(err_message);
// Handle errors
HandleHttpError(res, "POST");
}
// Ensure a return value in case of an error
return string_t();
});
}

Expand Down

0 comments on commit abcde70

Please sign in to comment.