Starting to prototype the parallel execution for Ray Runner #57

Draft · wants to merge 2 commits into base: master
Changes from all commits
111 changes: 81 additions & 30 deletions ray_beam_runner/portability/execution.py
@@ -46,10 +46,15 @@
from apache_beam.runners.worker import bundle_processor

from ray_beam_runner.portability.state import RayStateManager
from ray_beam_runner.portability.translations import StageTags

_LOGGER = logging.getLogger(__name__)


# TODO(pabloem): Stop hardcoding the number of blocks per task
BLOCKS_PER_TASK = 10


@ray.remote(num_returns="dynamic")
def ray_execute_bundle(
runner_context: "RayRunnerExecutionContext",
@@ -59,11 +64,22 @@ def ray_execute_bundle(
stage_timers: Mapping[translations.TimerFamilyId, bytes],
instruction_request_repr: Mapping[str, typing.Any],
dry_run=False,
stage_tags=None,
) -> Generator:
# generator returns:
# (serialized InstructionResponse, outputs,
# repeat of pcoll, data,
# delayed applications, repeat of pcoll, data)
"""Execute a Beam bundle as a ray task.

:returns A `Generator` with the following values:
- serialized InstructionResponse,
- dictionary of timers
- dictionary of delayed applications
- count of output pcollections,
- repeat of
- pcoll name
- pcoll block count
- repeat of pcoll block
"""

stage_tags = stage_tags or set()

instruction_request = beam_fn_api_pb2.InstructionRequest(
instruction_id=instruction_request_repr["instruction_id"],
@@ -74,9 +90,19 @@
cache_tokens=[instruction_request_repr["cache_token"]],
),
)

# TODO(pabloem): Verify that this prefix check reliably identifies stages
# with grouped (keyed) outputs.
expects_group = any(k.startswith('group') for k in expected_outputs.keys())

output_buffers: Mapping[
typing.Union[str, translations.TimerFamilyId], list
] = collections.defaultdict(list)
str, typing.Union[KeyBlockBasedDataBuffer, RandomBlockBasedDataBuffer]
] = collections.defaultdict(
KeyBlockBasedDataBuffer if expects_group else RandomBlockBasedDataBuffer
)

output_timer_buffers: Mapping[
translations.TimerFamilyId, list] = collections.defaultdict(list)

process_bundle_id = instruction_request.instruction_id

worker_handler = _get_worker_handler(
@@ -99,26 +125,26 @@
data_out = worker_handler.data_conn.output_stream(
process_bundle_id, transform_id
)
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()

expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list(
expected_outputs.keys()
)
expect_reads.extend(list(stage_timers.keys()))
try:
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()
except Exception:
_LOGGER.exception(
"Failed to write input elements for transform %s", transform_id)
raise

result_future = worker_handler.control_conn.push(instruction_request)

for output in worker_handler.data_conn.input_elements(
process_bundle_id,
expect_reads,
list(stage_timers.keys()) + list(expected_outputs.keys()),
abort_callback=lambda: (
result_future.is_done() and bool(result_future.get().error)
),
):
if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
output_buffers[
output_timer_buffers[
stage_timers[(output.transform_id, output.timer_family_id)]
].append(output.timers)
if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run:
@@ -138,10 +164,8 @@

returns = [result.SerializeToString()]

returns.append(len(output_buffers))
for pcoll, buffer in output_buffers.items():
returns.append(pcoll)
returns.append(buffer)
# Output timers are passed as a single object, since timer data is small.
returns.append(output_timer_buffers)

# Now we collect all the deferred inputs remaining from bundle execution.
# Deferred inputs can be:
@@ -157,15 +181,46 @@
runner_context,
)

returns.append(len(delayed_applications))
for pcoll, buffer in delayed_applications.items():
# Delayed applications are passed as a single object, since they are small.
returns.append(delayed_applications)

returns.append(len(output_buffers))
for pcoll, buffer in output_buffers.items():
returns.append(pcoll)
returns.append(buffer)
returns.append(buffer.num_blocks())
for i in range(buffer.num_blocks()):
returns.append(buffer.blocks[i])

for ret in returns:
yield ret
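
For reference, here is a minimal consumer-side sketch of the yield protocol above (mirroring _fetch_execution_output in ray_fn_runner.py; `ctx`, `bundle`, `coders`, `outputs`, `timers`, and `req` are placeholder arguments):

ref = ray_execute_bundle.remote(ctx, bundle, coders, outputs, timers,
                                instruction_request_repr=req)
gen = iter(ray.get(ref))  # ObjectRefGenerator over the yielded values
response = beam_fn_api_pb2.InstructionResponse.FromString(ray.get(next(gen)))
output_timers = ray.get(next(gen))   # dict of timer buffers
delayed_apps = ray.get(next(gen))    # dict of delayed applications
for _ in range(ray.get(next(gen))):  # count of output pcollections
    pcoll = ray.get(next(gen))       # pcoll name
    num_blocks = ray.get(next(gen))  # block count for this pcoll
    block_refs = [next(gen) for _ in range(num_blocks)]  # keep as ObjectRefs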


class RandomBlockBasedDataBuffer:
def __init__(self):
self._num_blocks = BLOCKS_PER_TASK
self.blocks = [[] for _ in range(self._num_blocks)]
self._total_data = 0

def num_blocks(self):
return min(self._total_data, self._num_blocks)

def append(self, data):
self.blocks[self._total_data % len(self.blocks)].append(data)
self._total_data += 1


class KeyBlockBasedDataBuffer:
def __init__(self):
self._num_blocks = 1
self.blocks = [[] for _ in range(self._num_blocks)]

def num_blocks(self):
return 1

def append(self, data):
# TODO: Determine how to extract the element's key so that all data for
# a given key lands in the same block.
self.blocks[0].append(data)
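
To illustrate how these buffers behave (with hypothetical byte-string records): RandomBlockBasedDataBuffer spreads appends round-robin over BLOCKS_PER_TASK blocks, while KeyBlockBasedDataBuffer keeps everything in one block so grouped data is never split across downstream tasks.

buf = RandomBlockBasedDataBuffer()
for record in (b"a", b"b", b"c"):
    buf.append(record)
# Three records land in the first three of the ten blocks:
assert buf.num_blocks() == 3
assert buf.blocks[0] == [b"a"]
assert buf.blocks[1] == [b"b"]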

def _get_source_transform_name(
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
transform_id: str,
@@ -263,6 +318,7 @@ def _fetch_decode_data(
data_references: List[ray.ObjectRef],
):
"""Fetch a PCollection's data and decode it."""
logging.warning("pabloem - Buffer is %s" % buffer_id)
if buffer_id.startswith(b"group"):
_, pcoll_id = translations.split_buffer_id(buffer_id)
transform = runner_context.pipeline_components.transforms[pcoll_id]
@@ -288,15 +344,10 @@
windowing=apache_beam.Windowing.from_runner_api(windowing_strategy, None),
)
else:
buffer = fn_execution.ListBuffer(
coder_impl=runner_context.pipeline_context.coders[coder_id].get_impl()
)
buffer = []

for block in ray.get(data_references):
# TODO(pabloem): Stop using ListBuffer, and use different
# buffers to pass data to Beam.
for elm in block:
buffer.append(elm)
buffer.extend(block)
return buffer


116 changes: 73 additions & 43 deletions ray_beam_runner/portability/ray_fn_runner.py
@@ -19,6 +19,7 @@
# pytype: skip-file
# mypy: check-untyped-defs
import collections
import concurrent.futures
import copy
import logging
import typing
@@ -48,6 +49,7 @@
import ray
from ray_beam_runner.portability.context_management import RayBundleContextManager
from ray_beam_runner.portability.execution import Bundle, _get_input_id
from ray_beam_runner.portability import translations as ray_translations
from ray_beam_runner.portability.execution import (
ray_execute_bundle,
merge_stage_results,
@@ -170,7 +172,8 @@ def run_pipeline(
translations.pack_combiners,
translations.lift_combiners,
translations.expand_sdf,
translations.expand_gbk,
ray_translations.expand_gbk,
ray_translations.expand_reshuffle,
translations.sink_flattens,
translations.greedily_fuse,
translations.read_to_impulse,
@@ -183,6 +186,7 @@
[
common_urns.primitives.FLATTEN.urn,
common_urns.primitives.GROUP_BY_KEY.urn,
common_urns.composites.RESHUFFLE.urn,
]
),
use_state_iterables=False,
@@ -248,6 +252,7 @@ def _run_stage(

final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse]

logging.warning("Executing stage %s", bundle_context_manager.stage.name)
while True:
(
last_result,
@@ -315,57 +320,82 @@ def _run_bundle(
process_bundle_id = "bundle_%s" % process_bundle_descriptor.id

pbd_id = process_bundle_descriptor.id
result_generator_ref = ray_execute_bundle.remote(
runner_execution_context,
input_bundle,
transform_to_buffer_coder,
data_output,
stage_timers,
instruction_request_repr={
"instruction_id": process_bundle_id,
"process_descriptor_id": pbd_id,
"cache_token": next(cache_token_generator),
},
)
result_generator = iter(ray.get(result_generator_ref))
result = beam_fn_api_pb2.InstructionResponse.FromString(
ray.get(next(result_generator))
)

output = []
num_outputs = ray.get(next(result_generator))
for _ in range(num_outputs):
pcoll = ray.get(next(result_generator))
data_ref = next(result_generator)
output.append(pcoll)
runner_execution_context.pcollection_buffers.put(pcoll, [data_ref])

delayed_applications = {}
num_delayed_applications = ray.get(next(result_generator))
for _ in range(num_delayed_applications):
pcoll = ray.get(next(result_generator))
data_ref = next(result_generator)
delayed_applications[pcoll] = data_ref
runner_execution_context.pcollection_buffers.put(pcoll, [data_ref])
input_data = input_bundle.input_data
result_generator_futures = []
if len(input_data) > 1:
raise RuntimeError(
"pabloem - Stage has multiple main input PCollections "
"which is unusual: %s"
% bundle_context_manager.stage.name)

input_id, obj_refs = list(input_data.items())[0]
logging.warning("pabloem - Running stage in PARALLEL AS WE HOPED - %d blocks", len(obj_refs))
# TODO(pabloem): This is a temporary hack: all blocks feeding a GroupByKey
# read are funneled into a single bundle so that keyed data stays together.
# TODO(pabloem): Implement proper per-key grouping of blocks instead.
if 'GroupByKey/Read' in input_id:
obj_refs = [obj_refs]
for i, obj_ref in enumerate(obj_refs):
result_generator_futures.append(ray_execute_bundle.remote(
runner_execution_context,
Bundle(input_timers=input_bundle.input_timers if i == 0 else {},
input_data={input_id: [obj_ref] if not isinstance(obj_ref, list) else obj_ref}),
transform_to_buffer_coder,
data_output,
stage_timers,
instruction_request_repr={
"instruction_id": process_bundle_id,
"process_descriptor_id": pbd_id,
"cache_token": next(cache_token_generator),
},
stage_tags=getattr(bundle_context_manager.stage, "tags", None)
))

final_result = None
final_output = set()
while True:
ready_results, result_generator_futures = ray.wait(result_generator_futures)
for ready_res in ready_results:
new_result, new_output, new_delayed_applications = self._fetch_execution_output(runner_execution_context, ready_res)
final_result = merge_stage_results(final_result, new_result) if final_result else new_result
final_output = final_output.union(new_output)
if not result_generator_futures:
break

(
watermarks_by_transform_and_timer_family,
newly_set_timers,
) = self._collect_written_timers(bundle_context_manager)

# TODO(pabloem): Add support for splitting of results.
# TODO: Merge delayed applications across bundles; currently only the last
# finished bundle's delayed applications are returned.
return final_result, newly_set_timers, new_delayed_applications, final_output

def _fetch_execution_output(self, runner_execution_context: RayRunnerExecutionContext, result_generator_ref):
result_generator = iter(ray.get(result_generator_ref))
response_str = ray.get(next(result_generator))
result = beam_fn_api_pb2.InstructionResponse.FromString(
response_str
)

output_timers = ray.get(next(result_generator))
delayed_applications = ray.get(next(result_generator))

for timer_id, timer_data in output_timers.items():
runner_execution_context.pcollection_buffers.put(timer_id, timer_data)
for pcoll, data_ref in delayed_applications.items():
runner_execution_context.pcollection_buffers.put(pcoll, [data_ref])

# After collecting deferred inputs, we 'pad' the structure with empty
# buffers for other expected inputs.
# if deferred_inputs or newly_set_timers:
# # The worker will be waiting on these inputs as well.
# for other_input in data_input:
# if other_input not in deferred_inputs:
# deferred_inputs[other_input] = ListBuffer(
# coder_impl=bundle_context_manager.get_input_coder_impl(
# other_input))
output = []
num_outputs = ray.get(next(result_generator))
for _1 in range(num_outputs):
pcoll = ray.get(next(result_generator))
output.append(pcoll)
blocks_per_pcoll = ray.get(next(result_generator))
for _2 in range(blocks_per_pcoll):
data_ref = next(result_generator)
runner_execution_context.pcollection_buffers.put(pcoll, [data_ref])

return result, newly_set_timers, delayed_applications, output
return result, output, delayed_applications

@staticmethod
def _collect_written_timers(
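Taken together, the two files implement a fan-out/fan-in pattern: each upstream task yields up to BLOCKS_PER_TASK data blocks as separate Ray ObjectRefs, and _run_bundle launches one ray_execute_bundle task per block, harvesting completed results with ray.wait. A condensed, self-contained sketch of that pattern (process_block is a hypothetical stand-in for ray_execute_bundle):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def process_block(block):
    # Stand-in for ray_execute_bundle: process one data block.
    return sum(block)

# One ObjectRef per block, as produced by RandomBlockBasedDataBuffer:
block_refs = [ray.put(list(range(n))) for n in (3, 5, 2)]
futures = [process_block.remote(ref) for ref in block_refs]
results = []
while futures:
    ready, futures = ray.wait(futures)  # harvest completions as they arrive
    results.extend(ray.get(ready))
# `results` now holds one partial result per block, to be merged the way
# merge_stage_results combines InstructionResponses.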