Fixes #14: Refactoring ray portable runner #18

Merged · 6 commits · Jun 16, 2022
3 changes: 2 additions & 1 deletion .flake8
@@ -16,4 +16,5 @@ ignore =
I
N
avoid-escape = no

per-file-ignores =
*ray_runner_test.py: B008
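For context, flake8-bugbear's B008 check ("Do not perform function calls in argument defaults") routinely fires on pytest-style test code, which is presumably why it is ignored for the runner tests. A minimal sketch of the pattern the per-file ignore permits; the function names below are hypothetical illustrations, not taken from ray_runner_test.py:

def default_pipeline_options():
    return ["--runner=TestRayRunner"]

# B008 would flag the call in the default argument below; the per-file
# ignore silences it for *ray_runner_test.py so tests can keep the idiom.
def test_wordcount(options=default_pipeline_options()):
    assert "--runner=TestRayRunner" in options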
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
@@ -23,6 +23,12 @@ jobs:
- name: Format
run: |
bash scripts/format.sh
- name: Install Ray Beam Runner
run: |
pip install -e .[test]
- name: Run Portability tests
run: |
pytest -r A ray_beam_runner/portability/ray_runner_test.py ray_beam_runner/portability/execution_test.py

LicenseCheck:
name: License Check
272 changes: 154 additions & 118 deletions ray_beam_runner/portability/context_management.py
@@ -15,137 +15,173 @@
# limitations under the License.
#
import typing
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataOutput
from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId
from apache_beam.runners.worker import bundle_processor
from apache_beam.utils import proto_utils

import ray
from ray_beam_runner.portability.execution import RayRunnerExecutionContext

class RayBundleContextManager:
ENCODED_IMPULSE_REFERENCE = ray.put([fn_execution.ENCODED_IMPULSE_VALUE])


def __init__(self,
execution_context: RayRunnerExecutionContext,
stage: translations.Stage,
) -> None:
self.execution_context = execution_context
self.stage = stage
# self.extract_bundle_inputs_and_outputs()
self.bundle_uid = self.execution_context.next_uid()

# Properties that are lazily initialized
self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]]
# a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
# is built after self._process_bundle_descriptor is initialized.
# This field can be used to tell whether the current bundle has timers.
self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]]

def __reduce__(self):
data = (self.execution_context,
self.stage)
deserializer = lambda args: RayBundleContextManager(args[0], args[1])
return (deserializer, data)

@property
def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
return []

def data_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
return endpoints_pb2.ApiServiceDescriptor(url='fake')

def state_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
return None

@property
def process_bundle_descriptor(self):
# type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
if self._process_bundle_descriptor is None:
self._process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
self._build_process_bundle_descriptor())
self._timer_coder_ids = fn_execution.BundleContextManager._build_timer_coders_id_map(self)
return self._process_bundle_descriptor

def _build_process_bundle_descriptor(self):
# Cannot be invoked until *after* _extract_endpoints is called.
# Always populate the timer_api_service_descriptor.
pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
id=self.bundle_uid,
transforms={
transform.unique_name: transform
for transform in self.stage.transforms
},
pcollections=dict(
self.execution_context.pipeline_components.pcollections.items()),
coders=dict(self.execution_context.pipeline_components.coders.items()),
windowing_strategies=dict(
self.execution_context.pipeline_components.windowing_strategies.
items()),
environments=dict(
self.execution_context.pipeline_components.environments.items()),
state_api_service_descriptor=self.state_api_service_descriptor(),
timer_api_service_descriptor=self.data_api_service_descriptor())

return pbd.SerializeToString()

def extract_bundle_inputs_and_outputs(self):
# type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]

"""Returns maps of transform names to PCollection identifiers.

Also mutates IO stages to point to the data ApiServiceDescriptor.

Returns:
A tuple of (data_input, data_output, expected_timer_output) dictionaries.
`data_input` is a dictionary mapping (transform_name, output_name) to a
PCollection buffer; `data_output` is a dictionary mapping
(transform_name, output_name) to a PCollection ID.
`expected_timer_output` is a dictionary mapping (transform_id, timer_family_id) to a buffer ID for timers.
"""
transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
data_output = {} # type: DataOutput
expected_timer_output = {} # type: OutputTimers
for transform in self.stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
coder_id = self.execution_context.data_channel_coders[translations.only_element(
transform.outputs.values())]
if pcoll_id == translations.IMPULSE_BUFFER:
buffer_actor = ray.get(self.execution_context.pcollection_buffers.get.remote(
transform.unique_name))
ray.get(buffer_actor.append.remote(fn_execution.ENCODED_IMPULSE_VALUE))
pcoll_id = transform.unique_name.encode('utf8')
else:
pass
transform_to_buffer_coder[transform.unique_name] = (
pcoll_id,
self.execution_context.safe_coders.get(coder_id, coder_id)
)
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
coder_id = self.execution_context.data_channel_coders[translations.only_element(
transform.inputs.values())]
else:
raise NotImplementedError
# TODO(pabloem): Figure out when we DO and we DON'T need this particular rewrite of coders.
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
# data_spec.api_service_descriptor.url = 'fake'
transform.spec.payload = data_spec.SerializeToString()
elif transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for timer_family_id in payload.timer_family_specs.keys():
expected_timer_output[(transform.unique_name, timer_family_id)] = (
translations.create_buffer_id(timer_family_id, 'timers'))
return transform_to_buffer_coder, data_output, expected_timer_output
class RayBundleContextManager:
def __init__(
self,
execution_context: RayRunnerExecutionContext,
stage: translations.Stage,
) -> None:
self.execution_context = execution_context
self.stage = stage
# self.extract_bundle_inputs_and_outputs()
self.bundle_uid = self.execution_context.next_uid()

# Properties that are lazily initialized
self._process_bundle_descriptor = (
None
) # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
self._worker_handlers = (
None
) # type: Optional[List[worker_handlers.WorkerHandler]]
# a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
# is built after self._process_bundle_descriptor is initialized.
# This field can be used to tell whether the current bundle has timers.
self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]]

def __reduce__(self):
    # Support pickling so Ray can ship the manager between processes;
    # only the execution context and stage are serialized.
    data = (self.execution_context, self.stage)

    def deserializer(args):
        # The callable must return the reconstructed instance for
        # unpickling to work.
        return RayBundleContextManager(args[0], args[1])

    return (deserializer, data)

@property
def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
return []

def data_api_service_descriptor(
self,
) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
return endpoints_pb2.ApiServiceDescriptor(url="fake")

def state_api_service_descriptor(
self,
) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
return None

@property
def process_bundle_descriptor(self) -> beam_fn_api_pb2.ProcessBundleDescriptor:
if self._process_bundle_descriptor is None:
self._process_bundle_descriptor = (
beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
self._build_process_bundle_descriptor()
)
)
self._timer_coder_ids = (
fn_execution.BundleContextManager._build_timer_coders_id_map(self)
)
return self._process_bundle_descriptor

def _build_process_bundle_descriptor(self):
# Cannot be invoked until *after* _extract_endpoints is called.
# Always populate the timer_api_service_descriptor.
pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
id=self.bundle_uid,
transforms={
transform.unique_name: transform for transform in self.stage.transforms
},
pcollections=dict(
self.execution_context.pipeline_components.pcollections.items()
),
coders=dict(self.execution_context.pipeline_components.coders.items()),
windowing_strategies=dict(
self.execution_context.pipeline_components.windowing_strategies.items()
),
environments=dict(
self.execution_context.pipeline_components.environments.items()
),
state_api_service_descriptor=self.state_api_service_descriptor(),
timer_api_service_descriptor=self.data_api_service_descriptor(),
)

return pbd.SerializeToString()

def get_bundle_inputs_and_outputs(
self,
) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]:
"""Returns maps of transform names to PCollection identifiers.

Also mutates IO stages to point to the data ApiServiceDescriptor.

Returns:
A tuple of (data_input, data_output, expected_timer_output) dictionaries.
`data_input` is a dictionary mapping (transform_name, output_name) to a
PCollection buffer; `data_output` is a dictionary mapping
(transform_name, output_name) to a PCollection ID.
`expected_timer_output` is a dictionary mapping (transform_id, timer_family_id) to a buffer ID for timers.
"""
return self.transform_to_buffer_coder, self.data_output, self.stage_timers

def setup(self):
Review comment (Member): Is there a reason why this method is only called after initialization of the ContextManager?

transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
data_output = {} # type: DataOutput
expected_timer_output = {} # type: OutputTimers
for transform in self.stage.transforms:
if transform.spec.urn in (
bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN,
):
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
coder_id = self.execution_context.data_channel_coders[
translations.only_element(transform.outputs.values())
]
if pcoll_id == translations.IMPULSE_BUFFER:
pcoll_id = transform.unique_name.encode("utf8")
self.execution_context.pcollection_buffers.put.remote(
pcoll_id, [ENCODED_IMPULSE_REFERENCE]
)
else:
pass
transform_to_buffer_coder[transform.unique_name] = (
pcoll_id,
self.execution_context.safe_coders.get(coder_id, coder_id),
)
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
coder_id = self.execution_context.data_channel_coders[
translations.only_element(transform.inputs.values())
]
else:
raise NotImplementedError
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
transform.spec.payload = data_spec.SerializeToString()
elif transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload
)
for timer_family_id in payload.timer_family_specs.keys():
expected_timer_output[
(transform.unique_name, timer_family_id)
] = translations.create_buffer_id(timer_family_id, "timers")
self.transform_to_buffer_coder, self.data_output, self.stage_timers = (
transform_to_buffer_coder,
data_output,
expected_timer_output,
)
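To make the refactored lifecycle concrete (and to echo the review question on setup above), a hedged usage sketch follows. The execution_context and stage values are assumed to come from the runner's pipeline-translation step; they are not constructed here, and this is not a verbatim call site from the PR.

mgr = RayBundleContextManager(execution_context, stage)
mgr.setup()  # populates transform_to_buffer_coder, data_output, stage_timers
inputs, outputs, timers = mgr.get_bundle_inputs_and_outputs()
pbd = mgr.process_bundle_descriptor  # lazily built on first access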