Skip to content

Commit

Permalink
implement runner-initiated split path
Browse files Browse the repository at this point in the history
  • Loading branch information
iasoon committed Nov 18, 2022
1 parent 5dd1eb4 commit e3d2d32
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 79 deletions.
237 changes: 208 additions & 29 deletions ray_beam_runner/portability/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import itertools
import logging
import random
import threading
import time
import typing
from typing import List
from typing import List, MutableMapping
from typing import Mapping
from typing import Optional
from typing import Generator
Expand Down Expand Up @@ -57,6 +59,7 @@ def ray_execute_bundle(
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
expected_outputs: translations.DataOutput,
stage_timers: Mapping[translations.TimerFamilyId, bytes],
split_manager,
instruction_request_repr: Mapping[str, typing.Any],
dry_run=False,
) -> Generator:
Expand All @@ -83,8 +86,6 @@ def ray_execute_bundle(
runner_context, instruction_request_repr["process_descriptor_id"]
)

_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)

input_data = {
k: _fetch_decode_data(
runner_context,
Expand All @@ -95,19 +96,30 @@ def ray_execute_bundle(
for k, objrefs in input_bundle.input_data.items()
}

for transform_id, elements in input_data.items():
data_out = worker_handler.data_conn.output_stream(
process_bundle_id, transform_id
)
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()

expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list(
expected_outputs.keys()
)
expect_reads.extend(list(stage_timers.keys()))

split_results = []
split_manager_thread = None
if split_manager:
split_manager_thread = threading.Thread(
target=_run_split_manager,
args=(
runner_context,
worker_handler,
split_manager,
input_data,
transform_buffer_coder,
instruction_request,
split_results,
),
)
split_manager_thread.start()

_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)
_send_input_data(worker_handler, input_data, process_bundle_id)
result_future = worker_handler.control_conn.push(instruction_request)

for output in worker_handler.data_conn.input_elements(
Expand All @@ -125,6 +137,8 @@ def ray_execute_bundle(
output_buffers[expected_outputs[output.transform_id]].append(output.data)

result: beam_fn_api_pb2.InstructionResponse = result_future.get()
if split_manager_thread:
split_manager_thread.join()

if result.process_bundle.requires_finalization:
finalize_request = beam_fn_api_pb2.InstructionRequest(
Expand All @@ -151,14 +165,27 @@ def ray_execute_bundle(
process_bundle_descriptor = runner_context.worker_manager.process_bundle_descriptor(
instruction_request_repr["process_descriptor_id"]
)
delayed_applications = _retrieve_delayed_applications(

deferred_inputs = {}

_add_delayed_applications_to_deferred_inputs(
result,
process_bundle_descriptor,
runner_context,
deferred_inputs,
)

returns.append(len(delayed_applications))
for pcoll, buffer in delayed_applications.items():
_add_residuals_and_channel_splits_to_deferred_inputs(
runner_context,
input_bundle.input_data,
transform_buffer_coder,
process_bundle_descriptor,
split_results,
deferred_inputs,
)

returns.append(len(deferred_inputs))
for pcoll, buffer in deferred_inputs.items():
returns.append(pcoll)
returns.append(buffer)

Expand Down Expand Up @@ -206,37 +233,101 @@ def _get_source_transform_name(
raise RuntimeError("No IO transform feeds %s" % transform_id)


def _retrieve_delayed_applications(
def _add_delayed_application_to_deferred_inputs(
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
delayed_application: beam_fn_api_pb2.DelayedBundleApplication,
deferred_inputs: MutableMapping[str, List[bytes]],
):
# TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
# time_delay = delayed_application.requested_time_delay
source_transform = _get_source_transform_name(
process_bundle_descriptor,
delayed_application.application.transform_id,
delayed_application.application.input_id,
)

if source_transform not in deferred_inputs:
deferred_inputs[source_transform] = []
deferred_inputs[source_transform].append(delayed_application.application.element)


def _add_delayed_applications_to_deferred_inputs(
bundle_result: beam_fn_api_pb2.InstructionResponse,
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
runner_context: "RayRunnerExecutionContext",
deferred_inputs: MutableMapping[str, List[bytes]],
):
"""Extract delayed applications from a bundle run.
A delayed application represents a user-initiated checkpoint, where user code
delays the consumption of a data element to checkpoint the previous elements
in a bundle.
"""
delayed_bundles = {}
for delayed_application in bundle_result.process_bundle.residual_roots:
# TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
# time_delay = delayed_application.requested_time_delay
source_transform = _get_source_transform_name(
_add_delayed_application_to_deferred_inputs(
process_bundle_descriptor,
delayed_application.application.transform_id,
delayed_application.application.input_id,
delayed_application,
deferred_inputs,
)

if source_transform not in delayed_bundles:
delayed_bundles[source_transform] = []
delayed_bundles[source_transform].append(
delayed_application.application.element
)

for consumer, data in delayed_bundles.items():
delayed_bundles[consumer] = [data]
def _add_residuals_and_channel_splits_to_deferred_inputs(
runner_context: "RayRunnerExecutionContext",
raw_inputs: Mapping[str, List[ray.ObjectRef]],
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
deferred_inputs: MutableMapping[str, List[bytes]],
):
prev_split_point = {} # transform id -> first residual offset
for split in splits:
for delayed_application in split.residual_roots:
_add_delayed_application_to_deferred_inputs(
process_bundle_descriptor,
delayed_application,
deferred_inputs,
)
for channel_split in split.channel_splits:
# Decode all input elements
byte_stream = b"".join(
(
element
for block in ray.get(raw_inputs[channel_split.transform_id])
for element in block
)
)
input_coder_id = transform_buffer_coder[channel_split.transform_id][1]
input_coder = runner_context.pipeline_context.coders[input_coder_id]

buffer_id = transform_buffer_coder[channel_split.transform_id][0]
if buffer_id.startswith(b"group:"):
coder_impl = coders.WindowedValueCoder(
coders.TupleCoder(
(
input_coder.wrapped_value_coder._coders[0],
input_coder.wrapped_value_coder._coders[1]._elem_coder,
)
),
input_coder.window_coder,
).get_impl()
else:
coder_impl = input_coder.get_impl()

all_elements = list(coder_impl.decode_all(byte_stream))

# split at first_residual_element index
end = prev_split_point.get(channel_split.transform_id, len(all_elements))
residual_elements = all_elements[channel_split.first_residual_element : end]
prev_split_point[
channel_split.transform_id
] = channel_split.first_residual_element

return delayed_bundles
if residual_elements:
encoded_residual = coder_impl.encode_all(residual_elements)

if channel_split.transform_id not in deferred_inputs:
deferred_inputs[channel_split.transform_id] = []
deferred_inputs[channel_split.transform_id].append(encoded_residual)


def _get_input_id(buffer_id, transform_name):
Expand Down Expand Up @@ -316,6 +407,94 @@ def _send_timers(
timer_out.close()


def _send_input_data(
worker_handler: worker_handlers.WorkerHandler,
input_data: Mapping[str, fn_execution.PartitionableBuffer],
process_bundle_id,
):
for transform_id, elements in input_data.items():
data_out = worker_handler.data_conn.output_stream(
process_bundle_id, transform_id
)
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()


def _run_split_manager(
runner_context: "RayRunnerExecutionContext",
worker_handler: worker_handlers.WorkerHandler,
split_manager,
inputs: Mapping[str, fn_execution.PartitionableBuffer],
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
instruction_request,
split_results_buf: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
):
read_transform_id, buffer_data = translations.only_element(inputs.items())
byte_stream = b"".join(buffer_data or [])
coder_id = transform_buffer_coder[read_transform_id][1]
coder_impl = runner_context.pipeline_context.coders[coder_id].get_impl()
num_elements = len(list(coder_impl.decode_all(byte_stream)))

# Start the split manager in case it wants to set any breakpoints.
split_manager_generator = split_manager(num_elements)
try:
split_fraction = next(split_manager_generator)
done = False
except StopIteration:
split_fraction = None
done = True

assert worker_handler is not None

# Execute the requested splits.
while not done:
if split_fraction is None:
split_result = None
else:
DesiredSplit = beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit
split_request = beam_fn_api_pb2.InstructionRequest(
process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
instruction_id=instruction_request.instruction_id,
desired_splits={
read_transform_id: DesiredSplit(
fraction_of_remainder=split_fraction,
estimated_input_elements=num_elements,
)
},
)
)
split_response = worker_handler.control_conn.push(
split_request
).get() # type: beam_fn_api_pb2.InstructionResponse
for t in (0.05, 0.1, 0.2):
if (
"Unknown process bundle" in split_response.error
or split_response.process_bundle_split
== beam_fn_api_pb2.ProcessBundleSplitResponse()
):
time.sleep(t)
split_response = worker_handler.control_conn.push(
split_request
).get()
if (
"Unknown process bundle" in split_response.error
or split_response.process_bundle_split
== beam_fn_api_pb2.ProcessBundleSplitResponse()
):
# It may have finished too fast.
split_result = None
elif split_response.error:
raise RuntimeError(split_response.error)
else:
split_result = split_response.process_bundle_split
split_results_buf.append(split_result)
try:
split_fraction = split_manager_generator.send(split_result)
except StopIteration:
break


@ray.remote
class _RayRunnerStats:
def __init__(self):
Expand Down
Loading

0 comments on commit e3d2d32

Please sign in to comment.