Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: rename immutable "RayBundleContextManager" to "RayBundleContext" #61

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ray_beam_runner/portability/context_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ray_beam_runner.portability.execution import RayRunnerExecutionContext


class RayBundleContextManager:
class RayBundleContext:
def __init__(
self,
execution_context: RayRunnerExecutionContext,
Expand All @@ -63,7 +63,7 @@ def __reduce__(self):
data = (self.execution_context, self.stage)

def deserializer(args):
RayBundleContextManager(args[0], args[1])
RayBundleContext(args[0], args[1])

return (deserializer, data)

Expand Down
36 changes: 17 additions & 19 deletions ray_beam_runner/portability/ray_fn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from apache_beam.portability.api import metrics_pb2

import ray
from ray_beam_runner.portability.context_management import RayBundleContextManager
from ray_beam_runner.portability.context_management import RayBundleContext
from ray_beam_runner.portability.execution import Bundle, _get_input_id
from ray_beam_runner.portability.execution import (
ray_execute_bundle,
Expand Down Expand Up @@ -223,7 +223,7 @@ def execute_pipeline(

try:
for stage in stages:
bundle_ctx = RayBundleContextManager(runner_execution_context, stage)
bundle_ctx = RayBundleContext(runner_execution_context, stage)
result = self._run_stage(runner_execution_context, bundle_ctx, queue)
monitoring_infos_by_stage[
bundle_ctx.stage.name
Expand All @@ -236,7 +236,7 @@ def execute_pipeline(
def _run_stage(
self,
runner_execution_context: RayRunnerExecutionContext,
bundle_context_manager: RayBundleContextManager,
bundle_context: RayBundleContext,
ready_bundles: collections.deque,
) -> beam_fn_api_pb2.InstructionResponse:

Expand All @@ -248,19 +248,19 @@ def _run_stage(
bundle_context_manager (execution.BundleContextManager): A description of
the stage to execute, and its context.
"""
bundle_context_manager.setup()
bundle_context.setup()
runner_execution_context.worker_manager.register_process_bundle_descriptor(
bundle_context_manager.process_bundle_descriptor
bundle_context.process_bundle_descriptor
)
input_timers: Mapping[
translations.TimerFamilyId, execution.PartitionableBuffer
] = {}

input_data = {
k: runner_execution_context.pcollection_buffers.get(
_get_input_id(bundle_context_manager.transform_to_buffer_coder[k][0], k)
_get_input_id(bundle_context.transform_to_buffer_coder[k][0], k)
)
for k in bundle_context_manager.transform_to_buffer_coder
for k in bundle_context.transform_to_buffer_coder
}

final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse]
Expand All @@ -273,7 +273,7 @@ def _run_stage(
bundle_outputs,
) = self._run_bundle(
runner_execution_context,
bundle_context_manager,
bundle_context,
Bundle(input_timers=input_timers, input_data=input_data),
)

Expand All @@ -298,7 +298,7 @@ def _run_stage(
# Store the required downstream side inputs into state so it is accessible
# for the worker when it runs bundles that consume this stage's output.
data_side_input = runner_execution_context.side_input_descriptors_by_stage.get(
bundle_context_manager.stage.name, {}
bundle_context.stage.name, {}
)
runner_execution_context.commit_side_inputs_to_state(data_side_input)

Expand All @@ -307,7 +307,7 @@ def _run_stage(
def _run_bundle(
self,
runner_execution_context: RayRunnerExecutionContext,
bundle_context_manager: RayBundleContextManager,
bundle_context: RayBundleContext,
input_bundle: Bundle,
) -> Tuple[
beam_fn_api_pb2.InstructionResponse,
Expand All @@ -320,13 +320,13 @@ def _run_bundle(
transform_to_buffer_coder,
data_output,
stage_timers,
) = bundle_context_manager.get_bundle_inputs_and_outputs()
) = bundle_context.get_bundle_inputs_and_outputs()

cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator(
static=False
)

process_bundle_descriptor = bundle_context_manager.process_bundle_descriptor
process_bundle_descriptor = bundle_context.process_bundle_descriptor

# TODO(pabloem): Are there two different IDs? the Bundle ID and PBD ID?
process_bundle_id = "bundle_%s" % process_bundle_descriptor.id
Expand Down Expand Up @@ -366,7 +366,7 @@ def _run_bundle(
(
watermarks_by_transform_and_timer_family,
newly_set_timers,
) = self._collect_written_timers(bundle_context_manager)
) = self._collect_written_timers(bundle_context)

# TODO(pabloem): Add support for splitting of results.

Expand All @@ -384,7 +384,7 @@ def _run_bundle(

@staticmethod
def _collect_written_timers(
bundle_context_manager: RayBundleContextManager,
bundle_context: RayBundleContext,
) -> Tuple[
Dict[translations.TimerFamilyId, timestamp.Timestamp],
Mapping[translations.TimerFamilyId, execution.PartitionableBuffer],
Expand All @@ -403,18 +403,16 @@ def _collect_written_timers(
timer_watermark_data = {}
newly_set_timers = {}

execution_context = bundle_context_manager.execution_context
execution_context = bundle_context.execution_context
buffer_manager = execution_context.pcollection_buffers

for (
transform_id,
timer_family_id,
), buffer_id in bundle_context_manager.stage_timers.items():
), buffer_id in bundle_context.stage_timers.items():
timer_buffer = buffer_manager.get(buffer_id)

coder_id = bundle_context_manager._timer_coder_ids[
(transform_id, timer_family_id)
]
coder_id = bundle_context._timer_coder_ids[(transform_id, timer_family_id)]

coder = execution_context.pipeline_context.coders[coder_id]
timer_coder_impl = coder.get_impl()
Expand Down