Skip to content

Commit

Permalink
feat: Add BatchExecuteCode function to UdfClient
Browse files Browse the repository at this point in the history
Bug: 369601693
Change-Id: I3fde0729dc25b6b22d23d80b8a60499ee4415899
GitOrigin-RevId: f1830a46e0587d0367513f26d8429c44d3409476
  • Loading branch information
lusayaa authored and copybara-github committed Dec 9, 2024
1 parent 6e47b99 commit 18260f2
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 7 deletions.
6 changes: 6 additions & 0 deletions components/udf/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class MockUdfClient : public UdfClient {
const google::protobuf::RepeatedPtrField<UDFArgument>&,
ExecutionMetadata& execution_metadata),
(const, override));
// Mock for UdfClient::BatchExecuteCode. The return type and the map
// parameter are each wrapped in an extra pair of parentheses so that the
// comma inside their template argument lists is not treated as a macro
// argument separator.
MOCK_METHOD((absl::StatusOr<absl::flat_hash_map<int32_t, std::string>>),
BatchExecuteCode,
(const RequestContextFactory& request_context_factory,
(absl::flat_hash_map<int32_t, UDFInput> & udf_input_map),
ExecutionMetadata& metadata),
(const, override));
// Mock for UdfClient::Stop; takes no arguments and returns an absl::Status.
MOCK_METHOD((absl::Status), Stop, (), (override));
MOCK_METHOD((absl::Status), SetCodeObject,
(CodeConfig, privacy_sandbox::server_common::log::PSLogContext&),
Expand Down
11 changes: 11 additions & 0 deletions components/udf/noop_udf_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ class NoopUdfClientImpl : public UdfClient {
return "";
}

// No-op batch execution: runs no code and returns an empty output string for
// every id present in `udf_input_map`. `request_context_factory` and
// `metadata` are not read or modified.
absl::StatusOr<absl::flat_hash_map<int32_t, std::string>> BatchExecuteCode(
    const RequestContextFactory& request_context_factory,
    absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
    ExecutionMetadata& metadata) const {
  absl::flat_hash_map<int32_t, std::string> empty_outputs;
  for (const auto& entry : udf_input_map) {
    // Keys of the input map are unique, so this always inserts a
    // default-constructed (empty) string for the id.
    empty_outputs.try_emplace(entry.first);
  }
  return empty_outputs;
}

// No-op: there is nothing to shut down, so this always returns OkStatus.
absl::Status Stop() { return absl::OkStatus(); }

absl::Status SetCodeObject(
Expand Down
64 changes: 57 additions & 7 deletions components/udf/udf_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "components/udf/udf_client.h"

#include <functional>
#include <future>
#include <memory>
#include <string>
#include <tuple>
Expand Down Expand Up @@ -159,6 +160,55 @@ class UdfClientImpl : public UdfClient {
absl::ToInt64Milliseconds(latency_recorder.GetLatency());
return *result;
}

// Executes one UDF per entry of `udf_input_map`, each on its own
// std::async task, and returns the outputs keyed by the entry's id.
//
// - An empty input map returns an empty result and leaves `metadata`
//   untouched.
// - Each UDFInput is moved out of `udf_input_map`; the map's values are in a
//   moved-from state when this call returns.
// - Entries whose execution fails are logged and omitted from the result;
//   the call itself still returns OK.
// - `metadata.custom_code_total_execution_time_micros` is set to the longest
//   single execution time observed across all entries (0 if none reported).
//
// All tasks are joined before returning, so the by-reference capture of
// `request_context_factory` cannot outlive the caller's object.
absl::StatusOr<absl::flat_hash_map<int32_t, std::string>> BatchExecuteCode(
    const RequestContextFactory& request_context_factory,
    absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
    ExecutionMetadata& metadata) const {
  absl::flat_hash_map<int32_t, std::string> results;
  if (udf_input_map.empty()) {
    PS_VLOG(5, request_context_factory.Get().GetPSLogContext())
        << "UDF input map is empty. Not executing any UDFs.";
    return results;
  }

  // Everything a single task hands back to the calling thread. Carrying the
  // per-run metadata through the future (instead of writing to the shared
  // `metadata` from worker threads) avoids a data race on
  // custom_code_total_execution_time_micros.
  struct TaskOutput {
    absl::StatusOr<std::string> result;
    ExecutionMetadata metadata;
  };

  absl::flat_hash_map<int32_t, std::future<TaskOutput>> responses;
  responses.reserve(udf_input_map.size());
  for (auto&& [id, udf_input] : udf_input_map) {
    responses[id] = std::async(
        std::launch::async,
        [this, &request_context_factory](UDFInput&& udf_input) {
          ExecutionMetadata single_run_metadata;
          auto result =
              this->ExecuteCode(request_context_factory,
                                std::move(udf_input.execution_metadata),
                                udf_input.arguments, single_run_metadata);
          return TaskOutput{std::move(result), single_run_metadata};
        },
        std::move(udf_input));
  }

  // Process responses. Aggregation happens only on this thread, after each
  // task has completed, so no synchronization on `metadata` is required.
  metadata.custom_code_total_execution_time_micros = 0;
  for (auto&& [id, response] : responses) {
    TaskOutput output = response.get();

    // Record the longest UDF execution time across all parallel executions.
    // An unset (nullopt) per-run time compares less than any value and is
    // therefore ignored by std::max.
    metadata.custom_code_total_execution_time_micros =
        std::max(metadata.custom_code_total_execution_time_micros,
                 output.metadata.custom_code_total_execution_time_micros);

    if (output.result.ok()) {
      results[id] = std::move(output.result.value());
    } else {
      PS_LOG(ERROR, request_context_factory.Get().GetPSLogContext())
          << "UDF Execution failed for partition id " << id << ": "
          << output.result.status();
    }
  }
  return results;
}

// Initializes the underlying RomaService and returns its status.
absl::Status Init() { return roma_service_.Init(); }

// Stops the underlying RomaService and returns its status.
absl::Status Stop() { return roma_service_.Stop(); }
Expand Down Expand Up @@ -256,13 +306,13 @@ class UdfClientImpl : public UdfClient {
const absl::Duration udf_timeout_;
const absl::Duration udf_update_timeout_;
int udf_min_log_level_;
// Per b/299667930, RomaService has been extended to support metadata storage
// as a side effect of RomaService::Execute(), making it no longer const.
// However, UDFClient::ExecuteCode() remains logically const, so RomaService
// is marked as mutable to allow usage within UDFClient::ExecuteCode(). For
// concerns about mutable or go/totw/174, RomaService is thread-safe, so
// losing the thread-safety of usage within a const function is a lesser
// concern.
// Per b/299667930, RomaService has been extended to support metadata
// storage as a side effect of RomaService::Execute(), making it no longer
// const. However, UDFClient::ExecuteCode() remains logically const, so
// RomaService is marked as mutable to allow usage within
// UDFClient::ExecuteCode(). For concerns about mutable or go/totw/174,
// RomaService is thread-safe, so losing the thread-safety of usage within a
// const function is a lesser concern.
mutable RomaService<std::weak_ptr<RequestContext>> roma_service_;
};

Expand Down
13 changes: 13 additions & 0 deletions components/udf/udf_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@

namespace kv_server {

// Input for a single UDF invocation within a batch: the execution metadata
// and the arguments handed to the UDF. Field order matters for callers that
// use designated initializers.
struct UDFInput {
UDFExecutionMetadata execution_metadata;
google::protobuf::RepeatedPtrField<UDFArgument> arguments;
};

// Timing information reported back to the caller of a UDF execution.
struct ExecutionMetadata {
// Total time for all custom code to execute, in microseconds. Unset until an
// execution records a value.
std::optional<int64_t> custom_code_total_execution_time_micros;
};

Expand All @@ -60,6 +66,13 @@ class UdfClient {
const google::protobuf::RepeatedPtrField<UDFArgument>& arguments,
ExecutionMetadata& metadata) const = 0;

// Executes multiple UDFs in parallel. Code object must be set before making
// this call.
//
// `udf_input_map` maps a caller-chosen id to the input of one UDF
// invocation; the returned map holds UDF outputs keyed by those same ids.
// `metadata` receives execution timing information.
//
// NOTE(review): `udf_input_map` is taken by non-const reference; the
// UdfClientImpl implementation moves from the map's values and omits ids
// whose execution failed from the result (failures are only logged). Confirm
// whether that is the intended contract for every implementation.
virtual absl::StatusOr<absl::flat_hash_map<int32_t, std::string>>
BatchExecuteCode(const RequestContextFactory& request_context_factory,
absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
ExecutionMetadata& metadata) const = 0;

virtual absl::Status Stop() = 0;

// Sets the code object that will be used for UDF execution
Expand Down
130 changes: 130 additions & 0 deletions components/udf/udf_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ using testing::Return;

namespace kv_server {
namespace {

// Text-format UDFExecutionMetadata whose request_metadata contains a single
// "hostname" field set to the empty string. Parsed with TextFormat in the
// tests below.
constexpr std::string_view kEmptyMetadata = R"(
request_metadata {
fields {
key: "hostname"
value {
string_value: ""
}
}
}
)";

absl::StatusOr<std::unique_ptr<UdfClient>> CreateUdfClient() {
Config<std::weak_ptr<RequestContext>> config;
config.number_of_workers = 1;
Expand Down Expand Up @@ -1104,5 +1116,123 @@ TEST_F(UdfClientTest, JsCallsLogCustomMetricFailedToLogError) {
EXPECT_THAT(metrics_logging_outcome, ContainsRegex("Failed to log metrics"));
}

// Verifies that BatchExecuteCode runs the UDF once per map entry and returns
// every output keyed by the entry's id, with and without request metadata.
TEST_F(UdfClientTest, BatchExecuteCodeSuccess) {
  auto udf_client = CreateUdfClient();
  // Fatal on failure: calling value() on a non-OK StatusOr below would crash
  // the test binary instead of reporting a clean failure.
  ASSERT_TRUE(udf_client.ok());

  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
      .js = "hello = (metadata, data) => 'Hello world! ' + "
            "JSON.stringify(metadata) + JSON.stringify(data);",
      .udf_handler_name = "hello",
      .logical_commit_time = 1,
      .version = 1,
  });
  EXPECT_TRUE(code_obj_status.ok());

  UDFExecutionMetadata udf_metadata;
  // A malformed fixture would otherwise silently yield empty metadata and a
  // confusing mismatch in the expectations below.
  EXPECT_TRUE(TextFormat::ParseFromString(kEmptyMetadata, &udf_metadata));

  google::protobuf::RepeatedPtrField<UDFArgument> args1;
  args1.Add([] {
    UDFArgument arg;
    arg.mutable_tags()->add_values()->set_string_value("tag1");
    arg.mutable_data()->set_string_value("key1");
    return arg;
  }());
  google::protobuf::RepeatedPtrField<UDFArgument> args2;
  args2.Add([] {
    UDFArgument arg;
    arg.mutable_tags()->add_values()->set_string_value("tag2");
    arg.mutable_data()->set_string_value("key2");
    return arg;
  }());

  // Entry 1 uses default execution metadata; entry 2 carries request metadata.
  absl::flat_hash_map<int32_t, UDFInput> input;
  input[1] = {.arguments = args1};
  input[2] = {.execution_metadata = udf_metadata, .arguments = args2};
  auto result = udf_client.value()->BatchExecuteCode(
      *request_context_factory_, input, execution_metadata_);
  ASSERT_TRUE(result.ok());
  auto udf_outputs = std::move(result.value());
  EXPECT_EQ(udf_outputs.size(), 2);
  EXPECT_EQ(
      udf_outputs[1],
      R"("Hello world! {\"udfInterfaceVersion\":1}{\"tags\":[\"tag1\"],\"data\":\"key1\"}")");
  EXPECT_EQ(
      udf_outputs[2],
      R"("Hello world! {\"udfInterfaceVersion\":1,\"requestMetadata\":{\"hostname\":\"\"}}{\"tags\":[\"tag2\"],\"data\":\"key2\"}")");

  absl::Status stop = udf_client.value()->Stop();
  EXPECT_TRUE(stop.ok());
}

// Verifies that an entry whose UDF throws is dropped from the result while
// the successful entry is still returned and the overall call reports OK.
TEST_F(UdfClientTest, BatchExecuteCodeIgnoresFailedPartition) {
  auto udf_client = CreateUdfClient();
  // Fatal on failure: calling value() on a non-OK StatusOr below would crash
  // the test binary instead of reporting a clean failure.
  ASSERT_TRUE(udf_client.ok());

  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
      .js =
          R"js(function hello(metadata, data) {
if(data.data == "valid_key") {return 'Hello world!';}
throw new Error('Oh no!');
})js",
      .udf_handler_name = "hello",
      .logical_commit_time = 1,
      .version = 1,
  });
  EXPECT_TRUE(code_obj_status.ok());

  // Entry 1 carries the key the UDF accepts.
  google::protobuf::RepeatedPtrField<UDFArgument> args1;
  args1.Add([] {
    UDFArgument arg;
    arg.mutable_tags()->add_values()->set_string_value("some_tag");
    arg.mutable_data()->set_string_value("valid_key");
    return arg;
  }());
  // Entry 2 carries any other key, which makes the UDF throw.
  google::protobuf::RepeatedPtrField<UDFArgument> args2;
  args2.Add([] {
    UDFArgument arg;
    arg.mutable_data()->set_string_value("invalid key");
    return arg;
  }());

  absl::flat_hash_map<int32_t, UDFInput> input;
  input[1] = {.arguments = args1};
  input[2] = {.arguments = args2};
  auto result = udf_client.value()->BatchExecuteCode(
      *request_context_factory_, input, execution_metadata_);
  ASSERT_TRUE(result.ok());
  auto udf_outputs = std::move(result.value());
  EXPECT_EQ(udf_outputs.size(), 1);
  EXPECT_EQ(udf_outputs[1], R"("Hello world!")");

  absl::Status stop = udf_client.value()->Stop();
  EXPECT_TRUE(stop.ok());
}

// Verifies that an empty input map yields an OK status and an empty result.
TEST_F(UdfClientTest, BatchExecuteCodeEmptyReturnsSuccess) {
  auto udf_client = CreateUdfClient();
  // Fatal on failure: calling value() on a non-OK StatusOr below would crash
  // the test binary instead of reporting a clean failure.
  ASSERT_TRUE(udf_client.ok());

  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
      .js = "hello = (metadata, data) => 'Hello world! ' + "
            "JSON.stringify(metadata) + JSON.stringify(data);",
      .udf_handler_name = "hello",
      .logical_commit_time = 1,
      .version = 1,
  });
  EXPECT_TRUE(code_obj_status.ok());

  absl::flat_hash_map<int32_t, UDFInput> input;
  auto result = udf_client.value()->BatchExecuteCode(
      *request_context_factory_, input, execution_metadata_);
  ASSERT_TRUE(result.ok());
  EXPECT_EQ(result->size(), 0);
  absl::Status stop = udf_client.value()->Stop();
  EXPECT_TRUE(stop.ok());
}
} // namespace
} // namespace kv_server

0 comments on commit 18260f2

Please sign in to comment.