From 1e4d36577e3f18018d7e9818c5f2782e82f84c07 Mon Sep 17 00:00:00 2001 From: Pablo E Date: Mon, 13 Jun 2022 14:49:10 -0700 Subject: [PATCH] Fix formatting --- .../portability/context_management.py | 273 +- ray_beam_runner/portability/execution.py | 746 ++-- ray_beam_runner/portability/execution_test.py | 65 +- ray_beam_runner/portability/ray_fn_runner.py | 545 +-- .../portability/ray_runner_test.py | 3720 +++++++++-------- ray_beam_runner/portability/state.py | 142 +- 6 files changed, 2908 insertions(+), 2583 deletions(-) diff --git a/ray_beam_runner/portability/context_management.py b/ray_beam_runner/portability/context_management.py index 09b39ee..7870b2c 100644 --- a/ray_beam_runner/portability/context_management.py +++ b/ray_beam_runner/portability/context_management.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging import typing +from typing import Dict from typing import List from typing import Optional +from typing import Tuple from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 @@ -25,6 +26,10 @@ from apache_beam.runners.portability.fn_api_runner import execution as fn_execution from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner import worker_handlers +from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer +from apache_beam.runners.portability.fn_api_runner.execution import OutputTimers +from apache_beam.runners.portability.fn_api_runner.translations import DataOutput +from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId from apache_beam.runners.worker import bundle_processor from apache_beam.utils import proto_utils @@ -35,124 +40,148 @@ class RayBundleContextManager: - - def __init__(self, - execution_context: RayRunnerExecutionContext, - stage: translations.Stage, - ) -> None: - self.execution_context = execution_context - self.stage = stage - # self.extract_bundle_inputs_and_outputs() - self.bundle_uid = self.execution_context.next_uid() - - # Properties that are lazily initialized - self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor] - self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]] - # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map - # is built after self._process_bundle_descriptor is initialized. - # This field can be used to tell whether current bundle has timers. - self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]] - - def __reduce__(self): - data = (self.execution_context, - self.stage) - deserializer = lambda args: RayBundleContextManager(args[0], args[1]) - return (deserializer, data) - - @property - def worker_handlers(self) -> List[worker_handlers.WorkerHandler]: - return [] - - def data_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: - return endpoints_pb2.ApiServiceDescriptor(url='fake') - - def state_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]: - return None - - @property - def process_bundle_descriptor(self): - # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor - if self._process_bundle_descriptor is None: - self._process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor.FromString( - self._build_process_bundle_descriptor()) - self._timer_coder_ids = fn_execution.BundleContextManager._build_timer_coders_id_map(self) - return self._process_bundle_descriptor - - def _build_process_bundle_descriptor(self): - # Cannot be invoked until *after* _extract_endpoints is called. - # Always populate the timer_api_service_descriptor. - pbd = beam_fn_api_pb2.ProcessBundleDescriptor( - id=self.bundle_uid, - transforms={ - transform.unique_name: transform - for transform in self.stage.transforms - }, - pcollections=dict( - self.execution_context.pipeline_components.pcollections.items()), - coders=dict(self.execution_context.pipeline_components.coders.items()), - windowing_strategies=dict( - self.execution_context.pipeline_components.windowing_strategies. - items()), - environments=dict( - self.execution_context.pipeline_components.environments.items()), - state_api_service_descriptor=self.state_api_service_descriptor(), - timer_api_service_descriptor=self.data_api_service_descriptor()) - - return pbd.SerializeToString() - - def get_bundle_inputs_and_outputs(self): - # type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]] - - """Returns maps of transform names to PCollection identifiers. - - Also mutates IO stages to point to the data ApiServiceDescriptor. - - Returns: - A tuple of (data_input, data_output, expected_timer_output) dictionaries. - `data_input` is a dictionary mapping (transform_name, output_name) to a - PCollection buffer; `data_output` is a dictionary mapping - (transform_name, output_name) to a PCollection ID. - `expected_timer_output` is a dictionary mapping transform_id and - timer family ID to a buffer id for timers. - """ - return self.transform_to_buffer_coder, self.data_output, self.stage_timers - - def setup(self): - transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {} - data_output = {} # type: DataOutput - expected_timer_output = {} # type: OutputTimers - for transform in self.stage.transforms: - if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, - bundle_processor.DATA_OUTPUT_URN): - pcoll_id = transform.spec.payload - if transform.spec.urn == bundle_processor.DATA_INPUT_URN: - coder_id = self.execution_context.data_channel_coders[translations.only_element( - transform.outputs.values())] - if pcoll_id == translations.IMPULSE_BUFFER: - pcoll_id = transform.unique_name.encode('utf8') - self.execution_context.pcollection_buffers.put.remote( - pcoll_id, [ENCODED_IMPULSE_REFERENCE]) - else: - pass - transform_to_buffer_coder[transform.unique_name] = ( - pcoll_id, - self.execution_context.safe_coders.get(coder_id, coder_id) - ) - elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: - data_output[transform.unique_name] = pcoll_id - coder_id = self.execution_context.data_channel_coders[translations.only_element( - transform.inputs.values())] - else: - raise NotImplementedError - # TODO(pabloem): Figure out when we DO and we DONT need this particular rewrite of coders. - data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) - # data_spec.api_service_descriptor.url = 'fake' - transform.spec.payload = data_spec.SerializeToString() - elif transform.spec.urn in translations.PAR_DO_URNS: - payload = proto_utils.parse_Bytes( - transform.spec.payload, beam_runner_api_pb2.ParDoPayload) - for timer_family_id in payload.timer_family_specs.keys(): - expected_timer_output[(transform.unique_name, timer_family_id)] = ( - translations.create_buffer_id(timer_family_id, 'timers')) - self.transform_to_buffer_coder, self.data_output, self.stage_timers = ( - transform_to_buffer_coder, data_output, expected_timer_output) \ No newline at end of file + def __init__( + self, + execution_context: RayRunnerExecutionContext, + stage: translations.Stage, + ) -> None: + self.execution_context = execution_context + self.stage = stage + # self.extract_bundle_inputs_and_outputs() + self.bundle_uid = self.execution_context.next_uid() + + # Properties that are lazily initialized + self._process_bundle_descriptor = ( + None + ) # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor] + self._worker_handlers = ( + None + ) # type: Optional[List[worker_handlers.WorkerHandler]] + # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map + # is built after self._process_bundle_descriptor is initialized. + # This field can be used to tell whether current bundle has timers. + self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]] + + def __reduce__(self): + data = (self.execution_context, self.stage) + + def deserializer(args): + RayBundleContextManager(args[0], args[1]) + + return (deserializer, data) + + @property + def worker_handlers(self) -> List[worker_handlers.WorkerHandler]: + return [] + + def data_api_service_descriptor( + self, + ) -> Optional[endpoints_pb2.ApiServiceDescriptor]: + return endpoints_pb2.ApiServiceDescriptor(url="fake") + + def state_api_service_descriptor( + self, + ) -> Optional[endpoints_pb2.ApiServiceDescriptor]: + return None + + @property + def process_bundle_descriptor(self) -> beam_fn_api_pb2.ProcessBundleDescriptor: + if self._process_bundle_descriptor is None: + self._process_bundle_descriptor = ( + beam_fn_api_pb2.ProcessBundleDescriptor.FromString( + self._build_process_bundle_descriptor() + ) + ) + self._timer_coder_ids = ( + fn_execution.BundleContextManager._build_timer_coders_id_map(self) + ) + return self._process_bundle_descriptor + + def _build_process_bundle_descriptor(self): + # Cannot be invoked until *after* _extract_endpoints is called. + # Always populate the timer_api_service_descriptor. + pbd = beam_fn_api_pb2.ProcessBundleDescriptor( + id=self.bundle_uid, + transforms={ + transform.unique_name: transform for transform in self.stage.transforms + }, + pcollections=dict( + self.execution_context.pipeline_components.pcollections.items() + ), + coders=dict(self.execution_context.pipeline_components.coders.items()), + windowing_strategies=dict( + self.execution_context.pipeline_components.windowing_strategies.items() + ), + environments=dict( + self.execution_context.pipeline_components.environments.items() + ), + state_api_service_descriptor=self.state_api_service_descriptor(), + timer_api_service_descriptor=self.data_api_service_descriptor(), + ) + + return pbd.SerializeToString() + + def get_bundle_inputs_and_outputs( + self, + ) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]: + """Returns maps of transform names to PCollection identifiers. + + Also mutates IO stages to point to the data ApiServiceDescriptor. + + Returns: + A tuple of (data_input, data_output, expected_timer_output) dictionaries. + `data_input` is a dictionary mapping (transform_name, output_name) to a + PCollection buffer; `data_output` is a dictionary mapping + (transform_name, output_name) to a PCollection ID. + `expected_timer_output` is a dictionary mapping transform_id and + timer family ID to a buffer id for timers. + """ + return self.transform_to_buffer_coder, self.data_output, self.stage_timers + + def setup(self): + transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {} + data_output = {} # type: DataOutput + expected_timer_output = {} # type: OutputTimers + for transform in self.stage.transforms: + if transform.spec.urn in ( + bundle_processor.DATA_INPUT_URN, + bundle_processor.DATA_OUTPUT_URN, + ): + pcoll_id = transform.spec.payload + if transform.spec.urn == bundle_processor.DATA_INPUT_URN: + coder_id = self.execution_context.data_channel_coders[ + translations.only_element(transform.outputs.values()) + ] + if pcoll_id == translations.IMPULSE_BUFFER: + pcoll_id = transform.unique_name.encode("utf8") + self.execution_context.pcollection_buffers.put.remote( + pcoll_id, [ENCODED_IMPULSE_REFERENCE] + ) + else: + pass + transform_to_buffer_coder[transform.unique_name] = ( + pcoll_id, + self.execution_context.safe_coders.get(coder_id, coder_id), + ) + elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: + data_output[transform.unique_name] = pcoll_id + coder_id = self.execution_context.data_channel_coders[ + translations.only_element(transform.inputs.values()) + ] + else: + raise NotImplementedError + data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) + transform.spec.payload = data_spec.SerializeToString() + elif transform.spec.urn in translations.PAR_DO_URNS: + payload = proto_utils.parse_Bytes( + transform.spec.payload, beam_runner_api_pb2.ParDoPayload + ) + for timer_family_id in payload.timer_family_specs.keys(): + expected_timer_output[ + (transform.unique_name, timer_family_id) + ] = translations.create_buffer_id(timer_family_id, "timers") + self.transform_to_buffer_coder, self.data_output, self.stage_timers = ( + transform_to_buffer_coder, + data_output, + expected_timer_output, + ) diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py index 08d1205..5e0a551 100644 --- a/ray_beam_runner/portability/execution.py +++ b/ray_beam_runner/portability/execution.py @@ -53,8 +53,8 @@ @ray.remote def ray_execute_bundle( - runner_context: 'RayRunnerExecutionContext', - input_bundle: 'Bundle', + runner_context: "RayRunnerExecutionContext", + input_bundle: "Bundle", transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]], expected_outputs: translations.DataOutput, stage_timers: Mapping[translations.TimerFamilyId, bytes], @@ -62,376 +62,466 @@ def ray_execute_bundle( dry_run=False, ) -> Tuple[str, List[Any], Mapping[str, ray.ObjectRef]]: - instruction_request = beam_fn_api_pb2.InstructionRequest( - instruction_id=instruction_request_repr['instruction_id'], - process_bundle=beam_fn_api_pb2.ProcessBundleRequest( - process_bundle_descriptor_id=instruction_request_repr['process_descriptor_id'], - cache_tokens=[instruction_request_repr['cache_token']])) - output_buffers: Mapping[typing.Union[str, translations.TimerFamilyId], list] = collections.defaultdict(list) - process_bundle_id = instruction_request.instruction_id - - worker_handler = _get_worker_handler( - runner_context, instruction_request_repr['process_descriptor_id']) - - _send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id) - - input_data = { - k: _fetch_decode_data(runner_context, _get_input_id(transform_buffer_coder[k][0], k), transform_buffer_coder[k][1], objrefs) - for k, objrefs in input_bundle.input_data.items() - } - - for transform_id, elements in input_data.items(): - data_out = worker_handler.data_conn.output_stream( - process_bundle_id, transform_id) - for byte_stream in elements: - data_out.write(byte_stream) - data_out.close() - - expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list(expected_outputs.keys()) - expect_reads.extend(list(stage_timers.keys())) - - result_future = worker_handler.control_conn.push(instruction_request) - - for output in worker_handler.data_conn.input_elements( - process_bundle_id, - expect_reads, - abort_callback=lambda: - (result_future.is_done() and bool(result_future.get().error))): - if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run: - output_buffers[expected_outputs[(output.transform_id, output.timer_family_id)]].append(output.data) - if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run: - output_buffers[expected_outputs[output.transform_id]].append(output.data) - - for pcoll, buffer in output_buffers.items(): - objrefs = [ray.put(buffer)] - runner_context.pcollection_buffers.put.remote(pcoll, objrefs) - output_buffers[pcoll] = objrefs - - result: beam_fn_api_pb2.InstructionResponse = result_future.get() - - # Now we collect all the deferred inputs remaining from bundle execution. - # Deferred inputs can be: - # - timers - # - SDK-initiated deferred applications of root elements - # - # TODO: Runner-initiated deferred applications of root elements - delayed_applications = _retrieve_delayed_applications( - result, - runner_context.worker_manager.process_bundle_descriptor( - instruction_request_repr['process_descriptor_id']), - runner_context) - - return result.SerializeToString(), list(output_buffers.keys()), delayed_applications + instruction_request = beam_fn_api_pb2.InstructionRequest( + instruction_id=instruction_request_repr["instruction_id"], + process_bundle=beam_fn_api_pb2.ProcessBundleRequest( + process_bundle_descriptor_id=instruction_request_repr[ + "process_descriptor_id" + ], + cache_tokens=[instruction_request_repr["cache_token"]], + ), + ) + output_buffers: Mapping[ + typing.Union[str, translations.TimerFamilyId], list + ] = collections.defaultdict(list) + process_bundle_id = instruction_request.instruction_id + + worker_handler = _get_worker_handler( + runner_context, instruction_request_repr["process_descriptor_id"] + ) + + _send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id) + + input_data = { + k: _fetch_decode_data( + runner_context, + _get_input_id(transform_buffer_coder[k][0], k), + transform_buffer_coder[k][1], + objrefs, + ) + for k, objrefs in input_bundle.input_data.items() + } + + for transform_id, elements in input_data.items(): + data_out = worker_handler.data_conn.output_stream( + process_bundle_id, transform_id + ) + for byte_stream in elements: + data_out.write(byte_stream) + data_out.close() + + expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list( + expected_outputs.keys() + ) + expect_reads.extend(list(stage_timers.keys())) + + result_future = worker_handler.control_conn.push(instruction_request) + + for output in worker_handler.data_conn.input_elements( + process_bundle_id, + expect_reads, + abort_callback=lambda: ( + result_future.is_done() and bool(result_future.get().error) + ), + ): + if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run: + output_buffers[ + expected_outputs[(output.transform_id, output.timer_family_id)] + ].append(output.data) + if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run: + output_buffers[expected_outputs[output.transform_id]].append(output.data) + + for pcoll, buffer in output_buffers.items(): + objrefs = [ray.put(buffer)] + runner_context.pcollection_buffers.put.remote(pcoll, objrefs) + output_buffers[pcoll] = objrefs + + result: beam_fn_api_pb2.InstructionResponse = result_future.get() + + # Now we collect all the deferred inputs remaining from bundle execution. + # Deferred inputs can be: + # - timers + # - SDK-initiated deferred applications of root elements + # - # TODO: Runner-initiated deferred applications of root elements + delayed_applications = _retrieve_delayed_applications( + result, + runner_context.worker_manager.process_bundle_descriptor( + instruction_request_repr["process_descriptor_id"] + ), + runner_context, + ) + + return result.SerializeToString(), list(output_buffers.keys()), delayed_applications def _retrieve_delayed_applications( bundle_result: beam_fn_api_pb2.InstructionResponse, process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, - runner_context: 'RayRunnerExecutionContext'): - """Extract delayed applications from a bundle run. - - A delayed application represents a user-initiated checkpoint, where user code - delays the consumption of a data element to checkpoint the previous elements - in a bundle. - """ - delayed_bundles = {} - for delayed_application in bundle_result.process_bundle.residual_roots: - # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. - time_delay = delayed_application.requested_time_delay - transform = process_bundle_descriptor.transforms[ - delayed_application.application.transform_id] - pcoll_name = transform.inputs[delayed_application.application.input_id] - - consumer_transform = translations.only_element([ - read_id for read_id, proto in process_bundle_descriptor.transforms.items() - if proto.spec.urn == bundle_processor.DATA_INPUT_URN - and pcoll_name in proto.outputs.values()]) - if consumer_transform not in delayed_bundles: - delayed_bundles[consumer_transform] = [] - delayed_bundles[consumer_transform].append(delayed_application.application.element) - - for consumer, data in delayed_bundles.items(): - ref = ray.put([data]) - runner_context.pcollection_buffers.put.remote(consumer, [ref]) - delayed_bundles[consumer] = ref - - return delayed_bundles + runner_context: "RayRunnerExecutionContext", +): + """Extract delayed applications from a bundle run. + + A delayed application represents a user-initiated checkpoint, where user code + delays the consumption of a data element to checkpoint the previous elements + in a bundle. + """ + delayed_bundles = {} + for delayed_application in bundle_result.process_bundle.residual_roots: + # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. + # time_delay = delayed_application.requested_time_delay + transform = process_bundle_descriptor.transforms[ + delayed_application.application.transform_id + ] + pcoll_name = transform.inputs[delayed_application.application.input_id] + + consumer_transform = translations.only_element( + [ + read_id + for read_id, proto in process_bundle_descriptor.transforms.items() + if proto.spec.urn == bundle_processor.DATA_INPUT_URN + and pcoll_name in proto.outputs.values() + ] + ) + if consumer_transform not in delayed_bundles: + delayed_bundles[consumer_transform] = [] + delayed_bundles[consumer_transform].append( + delayed_application.application.element + ) + + for consumer, data in delayed_bundles.items(): + ref = ray.put([data]) + runner_context.pcollection_buffers.put.remote(consumer, [ref]) + delayed_bundles[consumer] = ref + + return delayed_bundles def _get_input_id(buffer_id, transform_name): - """Get the 'buffer_id' for the input data we're retrieving. - - For most data, the buffer ID is as expected, but for IMPULSE readers, the - buffer ID is the consumer name. - """ - if isinstance(buffer_id, bytes) and ( - buffer_id.startswith(b'materialize') or buffer_id.startswith(b'timer') or buffer_id.startswith(b'group')): - buffer_id = buffer_id - else: - buffer_id = transform_name.encode('ascii') - return buffer_id - - -def _fetch_decode_data(runner_context: 'RayRunnerExecutionContext', buffer_id: bytes, coder_id: str, data_references: List[ray.ObjectRef]): - """Fetch a PCollection's data and decode it.""" - if buffer_id.startswith(b'group'): - _, pcoll_id = translations.split_buffer_id(buffer_id) - transform = runner_context.pipeline_components.transforms[pcoll_id] - out_pcoll = runner_context.pipeline_components.pcollections[translations.only_element(transform.outputs.values())] - windowing_strategy = runner_context.pipeline_components.windowing_strategies[out_pcoll.windowing_strategy_id] - postcoder = runner_context.pipeline_context.coders[coder_id] - precoder = coders.WindowedValueCoder( - coders.TupleCoder(( - postcoder.wrapped_value_coder._coders[0], - postcoder.wrapped_value_coder._coders[1]._elem_coder - )), - postcoder.window_coder) - buffer = fn_execution.GroupingBuffer( - pre_grouped_coder=precoder, - post_grouped_coder=postcoder, - windowing=apache_beam.Windowing.from_runner_api(windowing_strategy, None), - ) - else: - buffer = fn_execution.ListBuffer(coder_impl=runner_context.pipeline_context.coders[coder_id].get_impl()) - - for block in ray.get(data_references): - # TODO(pabloem): Stop using ListBuffer, and use different buffers to pass data to Beam - for elm in block: - buffer.append(elm) - return buffer - - -def _send_timers(worker_handler: worker_handlers.WorkerHandler, - input_bundle: 'Bundle', - stage_timers: Mapping[translations.TimerFamilyId, bytes], - process_bundle_id) -> None: - """Pass timers to the worker for processing.""" - for transform_id, timer_family_id in stage_timers.keys(): - timer_out = worker_handler.data_conn.output_timer_stream( - process_bundle_id, transform_id, timer_family_id) - for timer in input_bundle.input_timers.get((transform_id, timer_family_id), []): - timer_out.write(timer) - timer_out.close() + """Get the 'buffer_id' for the input data we're retrieving. + + For most data, the buffer ID is as expected, but for IMPULSE readers, the + buffer ID is the consumer name. + """ + if isinstance(buffer_id, bytes) and ( + buffer_id.startswith(b"materialize") + or buffer_id.startswith(b"timer") + or buffer_id.startswith(b"group") + ): + buffer_id = buffer_id + else: + buffer_id = transform_name.encode("ascii") + return buffer_id + + +def _fetch_decode_data( + runner_context: "RayRunnerExecutionContext", + buffer_id: bytes, + coder_id: str, + data_references: List[ray.ObjectRef], +): + """Fetch a PCollection's data and decode it.""" + if buffer_id.startswith(b"group"): + _, pcoll_id = translations.split_buffer_id(buffer_id) + transform = runner_context.pipeline_components.transforms[pcoll_id] + out_pcoll = runner_context.pipeline_components.pcollections[ + translations.only_element(transform.outputs.values()) + ] + windowing_strategy = runner_context.pipeline_components.windowing_strategies[ + out_pcoll.windowing_strategy_id + ] + postcoder = runner_context.pipeline_context.coders[coder_id] + precoder = coders.WindowedValueCoder( + coders.TupleCoder( + ( + postcoder.wrapped_value_coder._coders[0], + postcoder.wrapped_value_coder._coders[1]._elem_coder, + ) + ), + postcoder.window_coder, + ) + buffer = fn_execution.GroupingBuffer( + pre_grouped_coder=precoder, + post_grouped_coder=postcoder, + windowing=apache_beam.Windowing.from_runner_api(windowing_strategy, None), + ) + else: + buffer = fn_execution.ListBuffer( + coder_impl=runner_context.pipeline_context.coders[coder_id].get_impl() + ) + + for block in ray.get(data_references): + # TODO(pabloem): Stop using ListBuffer, and use different + # buffers to pass data to Beam. + for elm in block: + buffer.append(elm) + return buffer + + +def _send_timers( + worker_handler: worker_handlers.WorkerHandler, + input_bundle: "Bundle", + stage_timers: Mapping[translations.TimerFamilyId, bytes], + process_bundle_id, +) -> None: + """Pass timers to the worker for processing.""" + for transform_id, timer_family_id in stage_timers.keys(): + timer_out = worker_handler.data_conn.output_timer_stream( + process_bundle_id, transform_id, timer_family_id + ) + for timer in input_bundle.input_timers.get((transform_id, timer_family_id), []): + timer_out.write(timer) + timer_out.close() + @ray.remote class _RayRunnerStats: - def __init__(self): - self._bundle_uid = 0 + def __init__(self): + self._bundle_uid = 0 - def next_bundle(self): - self._bundle_uid += 1 - return self._bundle_uid + def next_bundle(self): + self._bundle_uid += 1 + return self._bundle_uid class RayWorkerHandlerManager: - def __init__(self): - self._process_bundle_descriptors = {} + def __init__(self): + self._process_bundle_descriptors = {} - def register_process_bundle_descriptor(self, process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor): - self._process_bundle_descriptors[process_bundle_descriptor.id] = process_bundle_descriptor.SerializeToString() + def register_process_bundle_descriptor( + self, process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor + ): + self._process_bundle_descriptors[ + process_bundle_descriptor.id + ] = process_bundle_descriptor.SerializeToString() - def process_bundle_descriptor(self, id): - pbd = self._process_bundle_descriptors[id] - if isinstance(pbd, beam_fn_api_pb2.ProcessBundleDescriptor): - return pbd - else: - return beam_fn_api_pb2.ProcessBundleDescriptor.FromString(pbd) + def process_bundle_descriptor(self, id): + pbd = self._process_bundle_descriptors[id] + if isinstance(pbd, beam_fn_api_pb2.ProcessBundleDescriptor): + return pbd + else: + return beam_fn_api_pb2.ProcessBundleDescriptor.FromString(pbd) class RayStage(translations.Stage): - def __reduce__(self): - data = ( - self.name, - [t.SerializeToString() for t in self.transforms], - self.downstream_side_inputs, - [], # self.must_follow, - self.parent, - self.environment, - self.forced_root) - deserializer = lambda *args: RayStage( - args[0], - [beam_runner_api_pb2.PTransform.FromString(s) for s in args[1]], - args[2], - args[3], - args[4], - args[5], - args[6],) - return (deserializer, data) - - @staticmethod - def from_Stage(stage: translations.Stage): - return RayStage( - stage.name, - stage.transforms, - stage.downstream_side_inputs, - # stage.must_follow, - [], - stage.parent, - stage.environment, - stage.forced_root, - ) + def __reduce__(self): + data = ( + self.name, + [t.SerializeToString() for t in self.transforms], + self.downstream_side_inputs, + [], # self.must_follow, + self.parent, + self.environment, + self.forced_root, + ) + + def deserializer(*args): + return RayStage( + args[0], + [beam_runner_api_pb2.PTransform.FromString(s) for s in args[1]], + args[2], + args[3], + args[4], + args[5], + args[6], + ) + + return (deserializer, data) + + @staticmethod + def from_Stage(stage: translations.Stage): + return RayStage( + stage.name, + stage.transforms, + stage.downstream_side_inputs, + # stage.must_follow, + [], + stage.parent, + stage.environment, + stage.forced_root, + ) @ray.remote class PcollectionBufferManager: - def __init__(self): - self.buffers = collections.defaultdict(list) + def __init__(self): + self.buffers = collections.defaultdict(list) - def put(self, pcoll, data_refs: List[ray.ObjectRef]): - self.buffers[pcoll].extend(data_refs) + def put(self, pcoll, data_refs: List[ray.ObjectRef]): + self.buffers[pcoll].extend(data_refs) - def get(self, pcoll) -> List[ray.ObjectRef]: - return self.buffers[pcoll] + def get(self, pcoll) -> List[ray.ObjectRef]: + return self.buffers[pcoll] @ray.remote class RayWatermarkManager(watermark_manager.WatermarkManager): - def __init__(self): - # the original WatermarkManager performs a lot of computation - # in its __init__ method. Because Ray calls __init__ whenever - # it deserializes an object, we'll move its setup elsewhere. - self._initialized = False - self._pcollections_by_name = {} - self._stages_by_name = {} - - def setup(self, stages): - if self._initialized: - return - logging.debug('initialized the RayWatermarkManager') - self._initialized = True - watermark_manager.WatermarkManager.setup(self, stages) + def __init__(self): + # the original WatermarkManager performs a lot of computation + # in its __init__ method. Because Ray calls __init__ whenever + # it deserializes an object, we'll move its setup elsewhere. + self._initialized = False + self._pcollections_by_name = {} + self._stages_by_name = {} + + def setup(self, stages): + if self._initialized: + return + logging.debug("initialized the RayWatermarkManager") + self._initialized = True + watermark_manager.WatermarkManager.setup(self, stages) class RayRunnerExecutionContext(object): - def __init__(self, - stages: List[translations.Stage], - pipeline_components: beam_runner_api_pb2.Components, - safe_coders: translations.SafeCoderMapping, - data_channel_coders: Mapping[str, str], - state_servicer: Optional[RayStateManager] = None, - worker_manager: Optional[RayWorkerHandlerManager] = None, - pcollection_buffers: PcollectionBufferManager = None, - ) -> None: - ray.util.register_serializer( - beam_runner_api_pb2.Components, - serializer=lambda x: x.SerializeToString(), - deserializer=lambda s: beam_runner_api_pb2.Components.FromString(s)) - ray.util.register_serializer( - pipeline_context.PipelineContext, - serializer=lambda x: x.proto.SerializeToString(), - deserializer=lambda s: pipeline_context.PipelineContext( - proto=beam_runner_api_pb2.Components.FromString(s))) - - self.pcollection_buffers = ( - pcollection_buffers or PcollectionBufferManager.remote()) - self.state_servicer = state_servicer or RayStateManager() - self.stages = [RayStage.from_Stage(s) - if not isinstance(s, RayStage) else s for s in stages] - self.side_input_descriptors_by_stage = ( - fn_execution.FnApiRunnerExecutionContext._build_data_side_inputs_map( - stages)) - self.pipeline_components = pipeline_components - self.safe_coders = safe_coders - self.data_channel_coders = data_channel_coders - - self.input_transform_to_buffer_id = { - t.unique_name: bytes(t.spec.payload) - for s in stages for t in s.transforms - if t.spec.urn == bundle_processor.DATA_INPUT_URN - } - self._watermark_manager = RayWatermarkManager.remote() - self.pipeline_context = pipeline_context.PipelineContext( - pipeline_components) - self.safe_windowing_strategies = { - # TODO: Enable safe_windowing_strategy after - # figuring out how to pickle the function. - # id: self._make_safe_windowing_strategy(id) - id: id - for id in pipeline_components.windowing_strategies.keys() - } - self.stats = _RayRunnerStats.remote() - self._uid = 0 - self.worker_manager = worker_manager or RayWorkerHandlerManager() - self.timer_coder_ids = self._build_timer_coders_id_map() - - @property - def watermark_manager(self): - # We don't need to wait for this line to execute with ray.get, - # because any further calls to the watermark manager actor will - # have to wait for it. - self._watermark_manager.setup.remote(self.stages) - return self._watermark_manager - - @staticmethod - def next_uid(): - # TODO(pabloem): Use stats actor for UIDs. - # return str(ray.get(self.stats.next_bundle.remote())) - # self._uid += 1 - return str(random.randint(0, 11111111)) - - def _build_timer_coders_id_map(self): - # type: () -> Dict[Tuple[str, str], str] - from apache_beam.utils import proto_utils - timer_coder_ids = {} - for transform_id, transform_proto in (self.pipeline_components.transforms.items()): - if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: - pardo_payload = proto_utils.parse_Bytes( - transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload) - for id, timer_family_spec in pardo_payload.timer_family_specs.items(): - timer_coder_ids[(transform_id, id)] = ( - timer_family_spec.timer_family_coder_id) - - def __reduce__(self): - # We need to implement custom serialization for this particular class - # because it contains several members that are protocol buffers, and - # protobufs are not pickleable due to being a C extension - so we serialize - # protobufs to string before serialzing the rest of the object. - data = (self.stages, + def __init__( + self, + stages: List[translations.Stage], + pipeline_components: beam_runner_api_pb2.Components, + safe_coders: translations.SafeCoderMapping, + data_channel_coders: Mapping[str, str], + state_servicer: Optional[RayStateManager] = None, + worker_manager: Optional[RayWorkerHandlerManager] = None, + pcollection_buffers: PcollectionBufferManager = None, + ) -> None: + ray.util.register_serializer( + beam_runner_api_pb2.Components, + serializer=lambda x: x.SerializeToString(), + deserializer=lambda s: beam_runner_api_pb2.Components.FromString(s), + ) + ray.util.register_serializer( + pipeline_context.PipelineContext, + serializer=lambda x: x.proto.SerializeToString(), + deserializer=lambda s: pipeline_context.PipelineContext( + proto=beam_runner_api_pb2.Components.FromString(s) + ), + ) + + self.pcollection_buffers = ( + pcollection_buffers or PcollectionBufferManager.remote() + ) + self.state_servicer = state_servicer or RayStateManager() + self.stages = [ + RayStage.from_Stage(s) if not isinstance(s, RayStage) else s for s in stages + ] + self.side_input_descriptors_by_stage = ( + fn_execution.FnApiRunnerExecutionContext._build_data_side_inputs_map(stages) + ) + self.pipeline_components = pipeline_components + self.safe_coders = safe_coders + self.data_channel_coders = data_channel_coders + + self.input_transform_to_buffer_id = { + t.unique_name: bytes(t.spec.payload) + for s in stages + for t in s.transforms + if t.spec.urn == bundle_processor.DATA_INPUT_URN + } + self._watermark_manager = RayWatermarkManager.remote() + self.pipeline_context = pipeline_context.PipelineContext(pipeline_components) + self.safe_windowing_strategies = { + # TODO: Enable safe_windowing_strategy after + # figuring out how to pickle the function. + # id: self._make_safe_windowing_strategy(id) + id: id + for id in pipeline_components.windowing_strategies.keys() + } + self.stats = _RayRunnerStats.remote() + self._uid = 0 + self.worker_manager = worker_manager or RayWorkerHandlerManager() + self.timer_coder_ids = self._build_timer_coders_id_map() + + @property + def watermark_manager(self): + # We don't need to wait for this line to execute with ray.get, + # because any further calls to the watermark manager actor will + # have to wait for it. + self._watermark_manager.setup.remote(self.stages) + return self._watermark_manager + + @staticmethod + def next_uid(): + # TODO(pabloem): Use stats actor for UIDs. + # return str(ray.get(self.stats.next_bundle.remote())) + # self._uid += 1 + return str(random.randint(0, 11111111)) + + def _build_timer_coders_id_map(self): + from apache_beam.utils import proto_utils + + timer_coder_ids = {} + for ( + transform_id, + transform_proto, + ) in self.pipeline_components.transforms.items(): + if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: + pardo_payload = proto_utils.parse_Bytes( + transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload + ) + for id, timer_family_spec in pardo_payload.timer_family_specs.items(): + timer_coder_ids[ + (transform_id, id) + ] = timer_family_spec.timer_family_coder_id + + def __reduce__(self): + # We need to implement custom serialization for this particular class + # because it contains several members that are protocol buffers, and + # protobufs are not pickleable due to being a C extension - so we serialize + # protobufs to string before serialzing the rest of the object. + data = ( + self.stages, self.pipeline_components.SerializeToString(), self.safe_coders, self.data_channel_coders, self.state_servicer, self.worker_manager, - self.pcollection_buffers) - deserializer = lambda *args: RayRunnerExecutionContext( - args[0], beam_runner_api_pb2.Components.FromString(args[1]), args[2], - args[3], args[4], args[5], args[6]) - return (deserializer, data) + self.pcollection_buffers, + ) + + def deserializer(*args): + return RayRunnerExecutionContext( + args[0], + beam_runner_api_pb2.Components.FromString(args[1]), + args[2], + args[3], + args[4], + args[5], + args[6], + ) + + return (deserializer, data) def merge_stage_results( previous_result: beam_fn_api_pb2.InstructionResponse, - last_result: beam_fn_api_pb2.InstructionResponse + last_result: beam_fn_api_pb2.InstructionResponse, ) -> beam_fn_api_pb2.InstructionResponse: - """ Merge InstructionResponse objects from executions of same stage bundles. - - This method is used to produce a global per-stage result object with - aggregated metrics and results. - """ - return ( - last_result - if previous_result is None else beam_fn_api_pb2.InstructionResponse( - process_bundle=beam_fn_api_pb2.ProcessBundleResponse( - monitoring_infos=monitoring_infos.consolidate( - itertools.chain( - previous_result.process_bundle.monitoring_infos, - last_result.process_bundle.monitoring_infos))), - error=previous_result.error or last_result.error)) - - -def _get_worker_handler(runner_context: RayRunnerExecutionContext, bundle_descriptor_id) -> worker_handlers.WorkerHandler: - worker_handler = worker_handlers.EmbeddedWorkerHandler( - None, # Unnecessary payload. - runner_context.state_servicer, - None, # Unnecessary provision info. - runner_context.worker_manager, - ) - worker_handler.worker.bundle_processor_cache.register( - runner_context.worker_manager.process_bundle_descriptor(bundle_descriptor_id) - ) - return worker_handler + """Merge InstructionResponse objects from executions of same stage bundles. + + This method is used to produce a global per-stage result object with + aggregated metrics and results. + """ + return ( + last_result + if previous_result is None + else beam_fn_api_pb2.InstructionResponse( + process_bundle=beam_fn_api_pb2.ProcessBundleResponse( + monitoring_infos=monitoring_infos.consolidate( + itertools.chain( + previous_result.process_bundle.monitoring_infos, + last_result.process_bundle.monitoring_infos, + ) + ) + ), + error=previous_result.error or last_result.error, + ) + ) + + +def _get_worker_handler( + runner_context: RayRunnerExecutionContext, bundle_descriptor_id +) -> worker_handlers.WorkerHandler: + worker_handler = worker_handlers.EmbeddedWorkerHandler( + None, # Unnecessary payload. + runner_context.state_servicer, + None, # Unnecessary provision info. + runner_context.worker_manager, + ) + worker_handler.worker.bundle_processor_cache.register( + runner_context.worker_manager.process_bundle_descriptor(bundle_descriptor_id) + ) + return worker_handler @dataclasses.dataclass class Bundle: - input_timers: Mapping[translations.TimerFamilyId, fn_execution.PartitionableBuffer] - input_data: Mapping[str, List[ray.ObjectRef]] \ No newline at end of file + input_timers: Mapping[translations.TimerFamilyId, fn_execution.PartitionableBuffer] + input_data: Mapping[str, List[ray.ObjectRef]] diff --git a/ray_beam_runner/portability/execution_test.py b/ray_beam_runner/portability/execution_test.py index fa966ef..3042f37 100644 --- a/ray_beam_runner/portability/execution_test.py +++ b/ray_beam_runner/portability/execution_test.py @@ -8,36 +8,35 @@ class StateHandlerTest(unittest.TestCase): - SAMPLE_STATE_KEY = apache_beam.portability.api.beam_fn_api_pb2.StateKey() - SAMPLE_INPUT_DATA = [ - b'bobby' - b'tables', - b'drop table', - b'where table_name > 12345' - ] - - @classmethod - def setUpClass(cls) -> None: - if not ray.is_initialized(): - ray.init() - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_data_stored_properly(self): - sh = RayStateManager() - with sh.process_instruction_id('anyinstruction'): - for data in StateHandlerTest.SAMPLE_INPUT_DATA: - sh.append_raw(StateHandlerTest.SAMPLE_STATE_KEY, data) - - with sh.process_instruction_id('anyinstruction'): - continuation_token = None - all_data = [] - while True: - data, continuation_token = sh.get_raw(StateHandlerTest.SAMPLE_STATE_KEY, continuation_token) - all_data.append(data) - if continuation_token is None: - break - - hc.assert_that(all_data, hc.contains_exactly(*StateHandlerTest.SAMPLE_INPUT_DATA)) + SAMPLE_STATE_KEY = apache_beam.portability.api.beam_fn_api_pb2.StateKey() + SAMPLE_INPUT_DATA = [b"bobby" b"tables", b"drop table", b"where table_name > 12345"] + + @classmethod + def setUpClass(cls) -> None: + if not ray.is_initialized(): + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_data_stored_properly(self): + sh = RayStateManager() + with sh.process_instruction_id("anyinstruction"): + for data in StateHandlerTest.SAMPLE_INPUT_DATA: + sh.append_raw(StateHandlerTest.SAMPLE_STATE_KEY, data) + + with sh.process_instruction_id("anyinstruction"): + continuation_token = None + all_data = [] + while True: + data, continuation_token = sh.get_raw( + StateHandlerTest.SAMPLE_STATE_KEY, continuation_token + ) + all_data.append(data) + if continuation_token is None: + break + + hc.assert_that( + all_data, hc.contains_exactly(*StateHandlerTest.SAMPLE_INPUT_DATA) + ) diff --git a/ray_beam_runner/portability/ray_fn_runner.py b/ray_beam_runner/portability/ray_fn_runner.py index 6c572e9..1388afb 100644 --- a/ray_beam_runner/portability/ray_fn_runner.py +++ b/ray_beam_runner/portability/ray_fn_runner.py @@ -42,13 +42,15 @@ from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer from apache_beam.transforms import environments -from apache_beam.utils import timestamp from apache_beam.utils import proto_utils import ray from ray_beam_runner.portability.context_management import RayBundleContextManager from ray_beam_runner.portability.execution import Bundle, _get_input_id -from ray_beam_runner.portability.execution import ray_execute_bundle, merge_stage_results +from ray_beam_runner.portability.execution import ( + ray_execute_bundle, + merge_stage_results, +) from ray_beam_runner.portability.execution import RayRunnerExecutionContext _LOGGER = logging.getLogger(__name__) @@ -57,259 +59,310 @@ def _setup_options(options: pipeline_options.PipelineOptions): - """Perform any necessary checkups and updates to input pipeline options""" - - # TODO(pabloem): Add input pipeline options - RuntimeValueProvider.set_runtime_options({}) - - experiments = ( - options.view_as(pipeline_options.DebugOptions).experiments or []) - if not 'beam_fn_api' in experiments: - experiments.append('beam_fn_api') - options.view_as(pipeline_options.DebugOptions).experiments = experiments - - -def _check_supported_requirements(pipeline_proto: beam_runner_api_pb2.Pipeline, supported_requirements: typing.Iterable[str]): - """Check that the input pipeline does not have unsuported requirements.""" - for requirement in pipeline_proto.requirements: - if requirement not in supported_requirements: - raise ValueError( - 'Unable to run pipeline with requirement: %s' % requirement) - for transform in pipeline_proto.components.transforms.values(): - if transform.spec.urn == common_urns.primitives.TEST_STREAM.urn: - raise NotImplementedError(transform.spec.urn) - elif transform.spec.urn in translations.PAR_DO_URNS: - payload = proto_utils.parse_Bytes( - transform.spec.payload, beam_runner_api_pb2.ParDoPayload) - for timer in payload.timer_family_specs.values(): - if timer.time_domain != beam_runner_api_pb2.TimeDomain.EVENT_TIME: - raise NotImplementedError(timer.time_domain) - - -def _pipeline_checks(pipeline: Pipeline, options: pipeline_options.PipelineOptions, supported_requirements: typing.Iterable[str]): - # This is sometimes needed if type checking is disabled - # to enforce that the inputs (and outputs) of GroupByKey operations - # are known to be KVs. - pipeline.visit( - group_by_key_input_visitor( - not options.view_as(pipeline_options.TypeOptions). - allow_non_deterministic_key_coders)) - - pipeline_proto = pipeline.to_runner_api(default_environment=environments.EmbeddedPythonEnvironment.default()) - fn_runner.FnApiRunner._validate_requirements(None, pipeline_proto) - - _check_supported_requirements(pipeline_proto, supported_requirements) - return pipeline_proto - - -class RayFnApiRunner(runner.PipelineRunner): + """Perform any necessary checkups and updates to input pipeline options""" + + # TODO(pabloem): Add input pipeline options + RuntimeValueProvider.set_runtime_options({}) + + experiments = options.view_as(pipeline_options.DebugOptions).experiments or [] + if "beam_fn_api" not in experiments: + experiments.append("beam_fn_api") + options.view_as(pipeline_options.DebugOptions).experiments = experiments + + +def _check_supported_requirements( + pipeline_proto: beam_runner_api_pb2.Pipeline, + supported_requirements: typing.Iterable[str], +): + """Check that the input pipeline does not have unsuported requirements.""" + for requirement in pipeline_proto.requirements: + if requirement not in supported_requirements: + raise ValueError( + "Unable to run pipeline with requirement: %s" % requirement + ) + for transform in pipeline_proto.components.transforms.values(): + if transform.spec.urn == common_urns.primitives.TEST_STREAM.urn: + raise NotImplementedError(transform.spec.urn) + elif transform.spec.urn in translations.PAR_DO_URNS: + payload = proto_utils.parse_Bytes( + transform.spec.payload, beam_runner_api_pb2.ParDoPayload + ) + for timer in payload.timer_family_specs.values(): + if timer.time_domain != beam_runner_api_pb2.TimeDomain.EVENT_TIME: + raise NotImplementedError(timer.time_domain) + + +def _pipeline_checks( + pipeline: Pipeline, + options: pipeline_options.PipelineOptions, + supported_requirements: typing.Iterable[str], +): + # This is sometimes needed if type checking is disabled + # to enforce that the inputs (and outputs) of GroupByKey operations + # are known to be KVs. + pipeline.visit( + group_by_key_input_visitor( + not options.view_as( + pipeline_options.TypeOptions + ).allow_non_deterministic_key_coders + ) + ) - def __init__( - self, - ) -> None: - - """Creates a new Ray Runner instance. - - Args: - progress_request_frequency: The frequency (in seconds) that the runner - waits before requesting progress from the SDK. - """ - super().__init__() - # TODO: figure out if this is necessary (probably, later) - self._progress_frequency = None - self._cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator() - - @staticmethod - def supported_requirements(): - # type: () -> Tuple[str, ...] - return ( - common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn, - common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn, - common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn, + pipeline_proto = pipeline.to_runner_api( + default_environment=environments.EmbeddedPythonEnvironment.default() ) + fn_runner.FnApiRunner._validate_requirements(None, pipeline_proto) - def run_pipeline(self, - pipeline: Pipeline, - options: pipeline_options.PipelineOptions - ) -> 'RayRunnerResult': - - # Checkup and set up input pipeline options - _setup_options(options) - - # Check pipeline and convert into protocol buffer representation - pipeline_proto = _pipeline_checks(pipeline, options, self.supported_requirements()) - - # Take the protocol buffer representation of the user's pipeline, and - # apply optimizations. - stage_context, stages = translations.create_and_optimize_stages( - copy.deepcopy(pipeline_proto), - phases=[ - # This is a list of transformations and optimizations to apply - # to a pipeline. - translations.annotate_downstream_side_inputs, - translations.fix_side_input_pcoll_coders, - translations.pack_combiners, - translations.lift_combiners, - translations.expand_sdf, - translations.expand_gbk, - translations.sink_flattens, - translations.greedily_fuse, - translations.read_to_impulse, - translations.impulse_to_input, - translations.sort_stages, - translations.setup_timer_mapping, - translations.populate_data_channel_coders, - ], - known_runner_urns=frozenset([ - common_urns.primitives.FLATTEN.urn, - common_urns.primitives.GROUP_BY_KEY.urn, - ]), - use_state_iterables=False, - is_drain=False) - return self.execute_pipeline(stage_context, stages) - - def execute_pipeline(self, - stage_context: translations.TransformContext, - stages: List[translations.Stage] - ) -> 'RayRunnerResult': - """Execute pipeline represented by a list of stages and a context.""" - logging.info('Starting pipeline of %d stages.' % len(stages)) - - runner_execution_context = RayRunnerExecutionContext( - stages, - stage_context.components, - stage_context.safe_coders, - stage_context.data_channel_coders) - - # Using this queue to hold 'bundles' that are ready to be processed - queue = collections.deque() - - try: - for stage in stages: - bundle_ctx = RayBundleContextManager(runner_execution_context, stage) - self._run_stage(runner_execution_context, bundle_ctx, queue) - finally: - pass - return RayRunnerResult(runner.PipelineState.DONE) - - def _run_stage(self, - runner_execution_context: RayRunnerExecutionContext, - bundle_context_manager: RayBundleContextManager, - ready_bundles: collections.deque, - ) -> beam_fn_api_pb2.InstructionResponse: - - """Run an individual stage. - - Args: - runner_execution_context: An object containing execution information for - the pipeline. - bundle_context_manager (execution.BundleContextManager): A description of - the stage to execute, and its context. - """ - bundle_context_manager.setup() - runner_execution_context.worker_manager.register_process_bundle_descriptor( - bundle_context_manager.process_bundle_descriptor) - input_timers: Mapping[translations.TimerFamilyId, - execution.PartitionableBuffer] = {} - - input_data = { - k: ray.get(runner_execution_context.pcollection_buffers.get.remote(_get_input_id(bundle_context_manager.transform_to_buffer_coder[k][0], k))) - for k in bundle_context_manager.transform_to_buffer_coder} - - final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] - - while True: - last_result, fired_timers, delayed_applications, bundle_outputs = ( - self._run_bundle( - runner_execution_context, - bundle_context_manager, - Bundle(input_timers=input_timers, input_data=input_data))) - - final_result = merge_stage_results(final_result, last_result) - if not delayed_applications and not fired_timers: - break - else: - # TODO: Enable following assertion after watermarking is implemented - # assert (ray.get(runner_execution_context.watermark_manager.get_stage_node.remote( - # bundle_context_manager.stage.name)).output_watermark() - # < timestamp.MAX_TIMESTAMP), ( - # 'wrong timestamp for %s. ' - # % ray.get(runner_execution_context.watermark_manager.get_stage_node.remote( - # bundle_context_manager.stage.name))) - input_data = delayed_applications - input_timers = fired_timers - - # Store the required downstream side inputs into state so it is accessible - # for the worker when it runs bundles that consume this stage's output. - data_side_input = ( - runner_execution_context.side_input_descriptors_by_stage.get( - bundle_context_manager.stage.name, {})) - # TODO(pabloem): Make sure that side inputs are being stored somewhere. - # runner_execution_context.commit_side_inputs_to_state(data_side_input) - - return final_result - - def _run_bundle( - self, - runner_execution_context: RayRunnerExecutionContext, - bundle_context_manager: RayBundleContextManager, - input_bundle: Bundle - ) -> Tuple[beam_fn_api_pb2.InstructionResponse, - Dict[translations.TimerFamilyId, ListBuffer], - Mapping[str, ray.ObjectRef], - List[Union[str, translations.TimerFamilyId]]]: - """Execute a bundle, and return a result object, and deferred inputs.""" - transform_to_buffer_coder, data_output, stage_timers = ( - bundle_context_manager.get_bundle_inputs_and_outputs()) - - cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator(static=False) - - # TODO(pabloem): Are there two different IDs? the Bundle ID and PBD ID? - process_bundle_id = 'bundle_%s' % bundle_context_manager.process_bundle_descriptor.id - - (result_str, output, delayed_applications) = ray.get(ray_execute_bundle.remote( - runner_execution_context, - input_bundle, - transform_to_buffer_coder, - data_output, - stage_timers, - instruction_request_repr={ - 'instruction_id': process_bundle_id, - 'process_descriptor_id': bundle_context_manager.process_bundle_descriptor.id, - 'cache_token': next(cache_token_generator) - } - )) - result = beam_fn_api_pb2.InstructionResponse.FromString(result_str) + _check_supported_requirements(pipeline_proto, supported_requirements) + return pipeline_proto - # TODO(pabloem): Add support for splitting of results. - # After collecting deferred inputs, we 'pad' the structure with empty - # buffers for other expected inputs. - # if deferred_inputs or newly_set_timers: - # # The worker will be waiting on these inputs as well. - # for other_input in data_input: - # if other_input not in deferred_inputs: - # deferred_inputs[other_input] = ListBuffer( - # coder_impl=bundle_context_manager.get_input_coder_impl( - # other_input)) +class RayFnApiRunner(runner.PipelineRunner): + def __init__( + self, + ) -> None: + + """Creates a new Ray Runner instance. + + Args: + progress_request_frequency: The frequency (in seconds) that the runner + waits before requesting progress from the SDK. + """ + super().__init__() + # TODO: figure out if this is necessary (probably, later) + self._progress_frequency = None + self._cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator() + + @staticmethod + def supported_requirements(): + # type: () -> Tuple[str, ...] + return ( + common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn, + common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn, + common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn, + ) + + def run_pipeline( + self, pipeline: Pipeline, options: pipeline_options.PipelineOptions + ) -> "RayRunnerResult": + + # Checkup and set up input pipeline options + _setup_options(options) + + # Check pipeline and convert into protocol buffer representation + pipeline_proto = _pipeline_checks( + pipeline, options, self.supported_requirements() + ) + + # Take the protocol buffer representation of the user's pipeline, and + # apply optimizations. + stage_context, stages = translations.create_and_optimize_stages( + copy.deepcopy(pipeline_proto), + phases=[ + # This is a list of transformations and optimizations to apply + # to a pipeline. + translations.annotate_downstream_side_inputs, + translations.fix_side_input_pcoll_coders, + translations.pack_combiners, + translations.lift_combiners, + translations.expand_sdf, + translations.expand_gbk, + translations.sink_flattens, + translations.greedily_fuse, + translations.read_to_impulse, + translations.impulse_to_input, + translations.sort_stages, + translations.setup_timer_mapping, + translations.populate_data_channel_coders, + ], + known_runner_urns=frozenset( + [ + common_urns.primitives.FLATTEN.urn, + common_urns.primitives.GROUP_BY_KEY.urn, + ] + ), + use_state_iterables=False, + is_drain=False, + ) + return self.execute_pipeline(stage_context, stages) + + def execute_pipeline( + self, + stage_context: translations.TransformContext, + stages: List[translations.Stage], + ) -> "RayRunnerResult": + """Execute pipeline represented by a list of stages and a context.""" + logging.info("Starting pipeline of %d stages." % len(stages)) + + runner_execution_context = RayRunnerExecutionContext( + stages, + stage_context.components, + stage_context.safe_coders, + stage_context.data_channel_coders, + ) + + # Using this queue to hold 'bundles' that are ready to be processed + queue = collections.deque() + + try: + for stage in stages: + bundle_ctx = RayBundleContextManager(runner_execution_context, stage) + self._run_stage(runner_execution_context, bundle_ctx, queue) + finally: + pass + return RayRunnerResult(runner.PipelineState.DONE) + + def _run_stage( + self, + runner_execution_context: RayRunnerExecutionContext, + bundle_context_manager: RayBundleContextManager, + ready_bundles: collections.deque, + ) -> beam_fn_api_pb2.InstructionResponse: + + """Run an individual stage. + + Args: + runner_execution_context: An object containing execution information for + the pipeline. + bundle_context_manager (execution.BundleContextManager): A description of + the stage to execute, and its context. + """ + bundle_context_manager.setup() + runner_execution_context.worker_manager.register_process_bundle_descriptor( + bundle_context_manager.process_bundle_descriptor + ) + input_timers: Mapping[ + translations.TimerFamilyId, execution.PartitionableBuffer + ] = {} + + input_data = { + k: ray.get( + runner_execution_context.pcollection_buffers.get.remote( + _get_input_id( + bundle_context_manager.transform_to_buffer_coder[k][0], k + ) + ) + ) + for k in bundle_context_manager.transform_to_buffer_coder + } - newly_set_timers = {} - return result, newly_set_timers, delayed_applications, output + final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] + + while True: + ( + last_result, + fired_timers, + delayed_applications, + bundle_outputs, + ) = self._run_bundle( + runner_execution_context, + bundle_context_manager, + Bundle(input_timers=input_timers, input_data=input_data), + ) + + final_result = merge_stage_results(final_result, last_result) + if not delayed_applications and not fired_timers: + break + else: + # TODO: Enable following assertion after watermarking is implemented + # assert (ray.get( + # runner_execution_context.watermark_manager + # .get_stage_node.remote( + # bundle_context_manager.stage.name)).output_watermark() + # < timestamp.MAX_TIMESTAMP), ( + # 'wrong timestamp for %s. ' + # % ray.get( + # runner_execution_context.watermark_manager + # .get_stage_node.remote( + # bundle_context_manager.stage.name))) + input_data = delayed_applications + input_timers = fired_timers + + # Store the required downstream side inputs into state so it is accessible + # for the worker when it runs bundles that consume this stage's output. + # data_side_input = ( + # runner_execution_context.side_input_descriptors_by_stage.get( + # bundle_context_manager.stage.name, {} + # ) + # ) + # TODO(pabloem): Make sure that side inputs are being stored somewhere. + # runner_execution_context.commit_side_inputs_to_state(data_side_input) + + return final_result + + def _run_bundle( + self, + runner_execution_context: RayRunnerExecutionContext, + bundle_context_manager: RayBundleContextManager, + input_bundle: Bundle, + ) -> Tuple[ + beam_fn_api_pb2.InstructionResponse, + Dict[translations.TimerFamilyId, ListBuffer], + Mapping[str, ray.ObjectRef], + List[Union[str, translations.TimerFamilyId]], + ]: + """Execute a bundle, and return a result object, and deferred inputs.""" + ( + transform_to_buffer_coder, + data_output, + stage_timers, + ) = bundle_context_manager.get_bundle_inputs_and_outputs() + + cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator( + static=False + ) + + # TODO(pabloem): Are there two different IDs? the Bundle ID and PBD ID? + process_bundle_id = ( + "bundle_%s" % bundle_context_manager.process_bundle_descriptor.id + ) + + pbd_id = bundle_context_manager.process_bundle_descriptor.id + (result_str, output, delayed_applications) = ray.get( + ray_execute_bundle.remote( + runner_execution_context, + input_bundle, + transform_to_buffer_coder, + data_output, + stage_timers, + instruction_request_repr={ + "instruction_id": process_bundle_id, + "process_descriptor_id": pbd_id, + "cache_token": next(cache_token_generator), + }, + ) + ) + result = beam_fn_api_pb2.InstructionResponse.FromString(result_str) + + # TODO(pabloem): Add support for splitting of results. + + # After collecting deferred inputs, we 'pad' the structure with empty + # buffers for other expected inputs. + # if deferred_inputs or newly_set_timers: + # # The worker will be waiting on these inputs as well. + # for other_input in data_input: + # if other_input not in deferred_inputs: + # deferred_inputs[other_input] = ListBuffer( + # coder_impl=bundle_context_manager.get_input_coder_impl( + # other_input)) + + newly_set_timers = {} + return result, newly_set_timers, delayed_applications, output class RayRunnerResult(runner.PipelineResult): - def __init__(self, state): - super().__init__(state) + def __init__(self, state): + super().__init__(state) - def wait_until_finish(self, duration=None): - return None + def wait_until_finish(self, duration=None): + return None - def metrics(self): - """Returns a queryable object including user metrics only.""" - # TODO(pabloem): Implement this based on _RayMetricsActor - raise NotImplementedError() + def metrics(self): + """Returns a queryable object including user metrics only.""" + # TODO(pabloem): Implement this based on _RayMetricsActor + raise NotImplementedError() - def monitoring_metrics(self): - """Returns a queryable object including all metrics.""" - # TODO(pabloem): Implement this based on _RayMetricsActor - raise NotImplementedError() + def monitoring_metrics(self): + """Returns a queryable object including all metrics.""" + # TODO(pabloem): Implement this based on _RayMetricsActor + raise NotImplementedError() diff --git a/ray_beam_runner/portability/ray_runner_test.py b/ray_beam_runner/portability/ray_runner_test.py index d83919e..7a67e34 100644 --- a/ray_beam_runner/portability/ray_runner_test.py +++ b/ray_beam_runner/portability/ray_runner_test.py @@ -35,7 +35,6 @@ from typing import Tuple import hamcrest # pylint: disable=ungrouped-imports -import pytest from hamcrest.core.matcher import Matcher from hamcrest.core.string_description import StringDescription @@ -51,7 +50,6 @@ from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.portability import python_urns -from apache_beam.runners.portability import fn_api_runner from apache_beam.runners.portability.fn_api_runner import fn_runner from apache_beam.runners.sdf_utils import RestrictionTrackerView from apache_beam.runners.worker import data_plane @@ -60,7 +58,6 @@ from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.transforms import environments from apache_beam.transforms import userstate from apache_beam.transforms import window from apache_beam.utils import timestamp @@ -69,1043 +66,1147 @@ import ray if statesampler.FAST_SAMPLER: - DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS + DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS else: - DEFAULT_SAMPLING_PERIOD_MS = 0 + DEFAULT_SAMPLING_PERIOD_MS = 0 _LOGGER = logging.getLogger(__name__) def _matcher_or_equal_to(value_or_matcher): - """Pass-thru for matchers, and wraps value inputs in an equal_to matcher.""" - if value_or_matcher is None: - return None - if isinstance(value_or_matcher, Matcher): - return value_or_matcher - return hamcrest.equal_to(value_or_matcher) + """Pass-thru for matchers, and wraps value inputs in an equal_to matcher.""" + if value_or_matcher is None: + return None + if isinstance(value_or_matcher, Matcher): + return value_or_matcher + return hamcrest.equal_to(value_or_matcher) def has_urn_and_labels(mi, urn, labels): - """Returns true if it the monitoring_info contains the labels and urn.""" - def contains_labels(mi, labels): - # Check all the labels and their values exist in the monitoring_info - return all(item in mi.labels.items() for item in labels.items()) + """Returns true if it the monitoring_info contains the labels and urn.""" - return contains_labels(mi, labels) and mi.urn == urn + def contains_labels(mi, labels): + # Check all the labels and their values exist in the monitoring_info + return all(item in mi.labels.items() for item in labels.items()) + return contains_labels(mi, labels) and mi.urn == urn -class RayFnApiRunnerTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - if not ray.is_initialized(): - ray.init(local_mode=True) - - def create_pipeline(self, is_drain=False): - return beam.Pipeline(runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner()) - - def test_assert_that(self): - with self.assertRaisesRegex(Exception, 'Failed assert'): - with self.create_pipeline() as p: - assert_that(p | beam.Create(['a', 'b']), equal_to(['a'])) - - def test_create(self): - with self.create_pipeline() as p: - assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b'])) - - def test_pardo(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create(['a', 'bc']) - | beam.Map(lambda e: e * 2) - | beam.Map(lambda e: e + 'x')) - assert_that(res, equal_to(['aax', 'bcbcx'])) - - def test_pardo_side_outputs(self): - def tee(elem, *tags): - for tag in tags: - if tag in elem: - yield beam.pvalue.TaggedOutput(tag, elem) - - with self.create_pipeline() as p: - xy = ( - p - | 'Create' >> beam.Create(['x', 'y', 'xy']) - | beam.FlatMap(tee, 'x', 'y').with_outputs()) - assert_that(xy.x, equal_to(['x', 'xy']), label='x') - assert_that(xy.y, equal_to(['y', 'xy']), label='y') - - def test_pardo_side_and_main_outputs(self): - def even_odd(elem): - yield elem - yield beam.pvalue.TaggedOutput('odd' if elem % 2 else 'even', elem) - - with self.create_pipeline() as p: - ints = p | beam.Create([1, 2, 3]) - named = ints | 'named' >> beam.FlatMap(even_odd).with_outputs( - 'even', 'odd', main='all') - assert_that(named.all, equal_to([1, 2, 3]), label='named.all') - assert_that(named.even, equal_to([2]), label='named.even') - assert_that(named.odd, equal_to([1, 3]), label='named.odd') - - unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs() - unnamed[None] | beam.Map(id) # pylint: disable=expression-not-assigned - assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all') - assert_that(unnamed.even, equal_to([2]), label='unnamed.even') - assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd') - - @unittest.skip('Side inputs not yet supported') - def test_pardo_side_inputs(self): - def cross_product(elem, sides): - for side in sides: - yield elem, side - - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create(['a', 'b', 'c']) - side = p | 'side' >> beam.Create(['x', 'y']) - assert_that( - main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)), - equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'), ('a', 'y'), ('b', 'y'), - ('c', 'y')])) - - @unittest.skip('Side inputs not yet supported') - def test_pardo_side_input_dependencies(self): - with self.create_pipeline() as p: - inputs = [p | beam.Create([None])] - for k in range(1, 10): - inputs.append( - inputs[0] | beam.ParDo( - ExpectingSideInputsFn(f'Do{k}'), - *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)])) - - @unittest.skip('Side inputs not yet supported') - def test_pardo_side_input_sparse_dependencies(self): - with self.create_pipeline() as p: - inputs = [] - - def choose_input(s): - return inputs[(389 + s * 5077) % len(inputs)] - - for k in range(30): - num_inputs = int((k * k % 16)**0.5) - if num_inputs == 0: - inputs.append(p | f'Create{k}' >> beam.Create([f'Create{k}'])) - else: - inputs.append( - choose_input(0) | beam.ParDo( - ExpectingSideInputsFn(f'Do{k}'), - *[ - beam.pvalue.AsList(choose_input(s)) - for s in range(1, num_inputs) - ])) - - @unittest.skip('Side inputs not yet supported') - def test_pardo_windowed_side_inputs(self): - with self.create_pipeline() as p: - # Now with some windowing. - pcoll = p | beam.Create(list( - range(10))) | beam.Map(lambda t: window.TimestampedValue(t, t)) - # Intentionally choosing non-aligned windows to highlight the transition. - main = pcoll | 'WindowMain' >> beam.WindowInto(window.FixedWindows(5)) - side = pcoll | 'WindowSide' >> beam.WindowInto(window.FixedWindows(7)) - res = main | beam.Map( - lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side)) - assert_that( - res, - equal_to([ - # The window [0, 5) maps to the window [0, 7). - (0, list(range(7))), - (1, list(range(7))), - (2, list(range(7))), - (3, list(range(7))), - (4, list(range(7))), - # The window [5, 10) maps to the window [7, 14). - (5, list(range(7, 10))), - (6, list(range(7, 10))), - (7, list(range(7, 10))), - (8, list(range(7, 10))), - (9, list(range(7, 10))) - ]), - label='windowed') - - @unittest.skip('Side inputs not yet supported') - def test_flattened_side_input(self, with_transcoding=True): - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create([None]) - side1 = p | 'side1' >> beam.Create([('a', 1)]) - side2 = p | 'side2' >> beam.Create([('b', 2)]) - if with_transcoding: - # Also test non-matching coder types (transcoding required) - third_element = [('another_type')] - else: - third_element = [('b', 3)] - side3 = p | 'side3' >> beam.Create(third_element) - side = (side1, side2) | beam.Flatten() - assert_that( - main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), - equal_to([(None, { - 'a': 1, 'b': 2 - })]), - label='CheckFlattenAsSideInput') - assert_that((side, side3) | 'FlattenAfter' >> beam.Flatten(), - equal_to([('a', 1), ('b', 2)] + third_element), - label='CheckFlattenOfSideInput') - - @unittest.skip('Side inputs not yet supported') - def test_gbk_side_input(self): - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create([None]) - side = p | 'side' >> beam.Create([('a', 1)]) | beam.GroupByKey() - assert_that( - main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), - equal_to([(None, { - 'a': [1] - })])) - - @unittest.skip('Side inputs not yet supported') - def test_multimap_side_input(self): - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create(['a', 'b']) - side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)]) - assert_that( - main | beam.Map( - lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)), - equal_to([('a', [1, 3]), ('b', [2])])) - - @unittest.skip('Side inputs not yet supported') - def test_multimap_multiside_input(self): - # A test where two transforms in the same stage consume the same PCollection - # twice as side input. - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create(['a', 'b']) - side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)]) - assert_that( - main | 'first map' >> beam.Map( - lambda k, - d, - l: (k, sorted(d[k]), sorted([e[1] for e in l])), - beam.pvalue.AsMultiMap(side), - beam.pvalue.AsList(side)) - | 'second map' >> beam.Map( - lambda k, - d, - l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), - beam.pvalue.AsMultiMap(side), - beam.pvalue.AsList(side)), - equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])])) - - @unittest.skip('Side inputs not yet supported') - def test_multimap_side_input_type_coercion(self): - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create(['a', 'b']) - # The type of this side-input is forced to Any (overriding type - # inference). Without type coercion to Tuple[Any, Any], the usage of this - # side-input in AsMultiMap() below should fail. - side = ( - p | 'side' >> beam.Create([('a', 1), ('b', 2), - ('a', 3)]).with_output_types(typing.Any)) - assert_that( - main | beam.Map( - lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)), - equal_to([('a', [1, 3]), ('b', [2])])) - - @unittest.skip('Side inputs not yet supported') - def test_pardo_unfusable_side_inputs(self): - def cross_product(elem, sides): - for side in sides: - yield elem, side - - with self.create_pipeline() as p: - pcoll = p | beam.Create(['a', 'b']) - assert_that( - pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)), - equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')])) - - with self.create_pipeline() as p: - pcoll = p | beam.Create(['a', 'b']) - derived = ((pcoll, ) | beam.Flatten() - | beam.Map(lambda x: (x, x)) - | beam.GroupByKey() - | 'Unkey' >> beam.Map(lambda kv: kv[0])) - assert_that( - pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(derived)), - equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')])) - - @unittest.skip('State not yet supported') - def test_pardo_state_only(self): - index_state_spec = userstate.CombiningValueStateSpec('index', sum) - value_and_index_state_spec = userstate.ReadModifyWriteStateSpec( - 'value:index', StrUtf8Coder()) - - class AddIndex(beam.DoFn): - def process( - self, - kv, - index=beam.DoFn.StateParam(index_state_spec), - value_and_index=beam.DoFn.StateParam(value_and_index_state_spec)): - k, v = kv - index.add(1) - value_and_index.write('%s:%s' % (v, index.read())) - yield k, v, index.read(), value_and_index.read() - - inputs = [('A', 'a')] * 2 + [('B', 'b')] * 3 - expected = [('A', 'a', 1, 'a:1'), ('A', 'a', 2, 'a:2'), - ('B', 'b', 1, 'b:1'), ('B', 'b', 2, 'b:2'), - ('B', 'b', 3, 'b:3')] - - with self.create_pipeline() as p: - assert_that( - p | beam.Create(inputs) | beam.ParDo(AddIndex()), equal_to(expected)) - - @unittest.skip('TestStream not yet supported') - def test_teststream_pardo_timers(self): - timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) - - class TimerDoFn(beam.DoFn): - def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): - unused_key, ts = element - timer.set(ts) - timer.set(2 * ts) - - @userstate.on_timer(timer_spec) - def process_timer(self): - yield 'fired' - - ts = ( - TestStream().add_elements([('k1', 10)]) # Set timer for 20 - .advance_watermark_to(100).add_elements([('k2', 100) - ]) # Set timer for 200 - .advance_watermark_to(1000)) - - with self.create_pipeline() as p: - _ = ( - p - | ts - | beam.ParDo(TimerDoFn()) - | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts))) - - #expected = [('fired', ts) for ts in (20, 200)] - #assert_that(actual, equal_to(expected)) - - @unittest.skip('Timers not yet supported') - def test_pardo_timers(self): - timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) - state_spec = userstate.CombiningValueStateSpec('num_called', sum) - - class TimerDoFn(beam.DoFn): - def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): - unused_key, ts = element - timer.set(ts) - timer.set(2 * ts) - - @userstate.on_timer(timer_spec) - def process_timer( - self, - ts=beam.DoFn.TimestampParam, - timer=beam.DoFn.TimerParam(timer_spec), - state=beam.DoFn.StateParam(state_spec)): - if state.read() == 0: - state.add(1) - timer.set(timestamp.Timestamp(micros=2 * ts.micros)) - yield 'fired' - - with self.create_pipeline() as p: - actual = ( - p - | beam.Create([('k1', 10), ('k2', 100)]) - | beam.ParDo(TimerDoFn()) - | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts))) - - expected = [('fired', ts) for ts in (20, 200, 40, 400)] - assert_that(actual, equal_to(expected)) - - @unittest.skip('Timers not yet supported') - def test_pardo_timers_clear(self): - timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) - clear_timer_spec = userstate.TimerSpec( - 'clear_timer', userstate.TimeDomain.WATERMARK) - - class TimerDoFn(beam.DoFn): - def process( - self, - element, - timer=beam.DoFn.TimerParam(timer_spec), - clear_timer=beam.DoFn.TimerParam(clear_timer_spec)): - unused_key, ts = element - timer.set(ts) - timer.set(2 * ts) - clear_timer.set(ts) - clear_timer.clear() - - @userstate.on_timer(timer_spec) - def process_timer(self): - yield 'fired' - - @userstate.on_timer(clear_timer_spec) - def process_clear_timer(self): - yield 'should not fire' - - with self.create_pipeline() as p: - actual = ( - p - | beam.Create([('k1', 10), ('k2', 100)]) - | beam.ParDo(TimerDoFn()) - | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts))) - - expected = [('fired', ts) for ts in (20, 200)] - assert_that(actual, equal_to(expected)) - - @unittest.skip('Timers not yet supported') - def test_pardo_state_timers(self): - self._run_pardo_state_timers(windowed=False) - - @unittest.skip('Timers not yet supported') - def test_pardo_state_timers_non_standard_coder(self): - self._run_pardo_state_timers(windowed=False, key_type=Any) - - @unittest.skip('Timers not yet supported') - def test_windowed_pardo_state_timers(self): - self._run_pardo_state_timers(windowed=True) - - def _run_pardo_state_timers(self, windowed, key_type=None): - """ - :param windowed: If True, uses an interval window, otherwise a global window - :param key_type: Allows to override the inferred key type. This is useful to - test the use of non-standard coders, e.g. Python's FastPrimitivesCoder. - """ - state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder()) - timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) - elements = list('abcdefgh') - key = 'key' - buffer_size = 3 - - class BufferDoFn(beam.DoFn): - def process( - self, - kv, - ts=beam.DoFn.TimestampParam, - timer=beam.DoFn.TimerParam(timer_spec), - state=beam.DoFn.StateParam(state_spec)): - _, element = kv - state.add(element) - buffer = state.read() - # For real use, we'd keep track of this size separately. - if len(list(buffer)) >= 3: - state.clear() - yield buffer +class RayFnApiRunnerTest(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + if not ray.is_initialized(): + ray.init(local_mode=True) + + def create_pipeline(self, is_drain=False): + return beam.Pipeline( + runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner() + ) + + def test_assert_that(self): + with self.assertRaisesRegex(Exception, "Failed assert"): + with self.create_pipeline() as p: + assert_that(p | beam.Create(["a", "b"]), equal_to(["a"])) + + def test_create(self): + with self.create_pipeline() as p: + assert_that(p | beam.Create(["a", "b"]), equal_to(["a", "b"])) + + def test_pardo(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create(["a", "bc"]) + | beam.Map(lambda e: e * 2) + | beam.Map(lambda e: e + "x") + ) + assert_that(res, equal_to(["aax", "bcbcx"])) + + def test_pardo_side_outputs(self): + def tee(elem, *tags): + for tag in tags: + if tag in elem: + yield beam.pvalue.TaggedOutput(tag, elem) + + with self.create_pipeline() as p: + xy = ( + p + | "Create" >> beam.Create(["x", "y", "xy"]) + | beam.FlatMap(tee, "x", "y").with_outputs() + ) + assert_that(xy.x, equal_to(["x", "xy"]), label="x") + assert_that(xy.y, equal_to(["y", "xy"]), label="y") + + def test_pardo_side_and_main_outputs(self): + def even_odd(elem): + yield elem + yield beam.pvalue.TaggedOutput("odd" if elem % 2 else "even", elem) + + with self.create_pipeline() as p: + ints = p | beam.Create([1, 2, 3]) + named = ints | "named" >> beam.FlatMap(even_odd).with_outputs( + "even", "odd", main="all" + ) + assert_that(named.all, equal_to([1, 2, 3]), label="named.all") + assert_that(named.even, equal_to([2]), label="named.even") + assert_that(named.odd, equal_to([1, 3]), label="named.odd") + + unnamed = ints | "unnamed" >> beam.FlatMap(even_odd).with_outputs() + unnamed[None] | beam.Map(id) # pylint: disable=expression-not-assigned + assert_that(unnamed[None], equal_to([1, 2, 3]), label="unnamed.all") + assert_that(unnamed.even, equal_to([2]), label="unnamed.even") + assert_that(unnamed.odd, equal_to([1, 3]), label="unnamed.odd") + + @unittest.skip("Side inputs not yet supported") + def test_pardo_side_inputs(self): + def cross_product(elem, sides): + for side in sides: + yield elem, side + + with self.create_pipeline() as p: + main = p | "main" >> beam.Create(["a", "b", "c"]) + side = p | "side" >> beam.Create(["x", "y"]) + assert_that( + main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)), + equal_to( + [ + ("a", "x"), + ("b", "x"), + ("c", "x"), + ("a", "y"), + ("b", "y"), + ("c", "y"), + ] + ), + ) + + @unittest.skip("Side inputs not yet supported") + def test_pardo_side_input_dependencies(self): + with self.create_pipeline() as p: + inputs = [p | beam.Create([None])] + for k in range(1, 10): + inputs.append( + inputs[0] + | beam.ParDo( + ExpectingSideInputsFn(f"Do{k}"), + *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)], + ) + ) + + @unittest.skip("Side inputs not yet supported") + def test_pardo_side_input_sparse_dependencies(self): + with self.create_pipeline() as p: + inputs = [] + + def choose_input(s): + return inputs[(389 + s * 5077) % len(inputs)] + + for k in range(30): + num_inputs = int((k * k % 16) ** 0.5) + if num_inputs == 0: + inputs.append(p | f"Create{k}" >> beam.Create([f"Create{k}"])) + else: + inputs.append( + choose_input(0) + | beam.ParDo( + ExpectingSideInputsFn(f"Do{k}"), + *[ + beam.pvalue.AsList(choose_input(s)) + for s in range(1, num_inputs) + ], + ) + ) + + @unittest.skip("Side inputs not yet supported") + def test_pardo_windowed_side_inputs(self): + with self.create_pipeline() as p: + # Now with some windowing. + pcoll = ( + p + | beam.Create(list(range(10))) + | beam.Map(lambda t: window.TimestampedValue(t, t)) + ) + # Intentionally choosing non-aligned windows to highlight the transition. + main = pcoll | "WindowMain" >> beam.WindowInto(window.FixedWindows(5)) + side = pcoll | "WindowSide" >> beam.WindowInto(window.FixedWindows(7)) + res = main | beam.Map(lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side)) + assert_that( + res, + equal_to( + [ + # The window [0, 5) maps to the window [0, 7). + (0, list(range(7))), + (1, list(range(7))), + (2, list(range(7))), + (3, list(range(7))), + (4, list(range(7))), + # The window [5, 10) maps to the window [7, 14). + (5, list(range(7, 10))), + (6, list(range(7, 10))), + (7, list(range(7, 10))), + (8, list(range(7, 10))), + (9, list(range(7, 10))), + ] + ), + label="windowed", + ) + + @unittest.skip("Side inputs not yet supported") + def test_flattened_side_input(self, with_transcoding=True): + with self.create_pipeline() as p: + main = p | "main" >> beam.Create([None]) + side1 = p | "side1" >> beam.Create([("a", 1)]) + side2 = p | "side2" >> beam.Create([("b", 2)]) + if with_transcoding: + # Also test non-matching coder types (transcoding required) + third_element = [("another_type")] + else: + third_element = [("b", 3)] + side3 = p | "side3" >> beam.Create(third_element) + side = (side1, side2) | beam.Flatten() + assert_that( + main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), + equal_to([(None, {"a": 1, "b": 2})]), + label="CheckFlattenAsSideInput", + ) + assert_that( + (side, side3) | "FlattenAfter" >> beam.Flatten(), + equal_to([("a", 1), ("b", 2)] + third_element), + label="CheckFlattenOfSideInput", + ) + + @unittest.skip("Side inputs not yet supported") + def test_gbk_side_input(self): + with self.create_pipeline() as p: + main = p | "main" >> beam.Create([None]) + side = p | "side" >> beam.Create([("a", 1)]) | beam.GroupByKey() + assert_that( + main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)), + equal_to([(None, {"a": [1]})]), + ) + + @unittest.skip("Side inputs not yet supported") + def test_multimap_side_input(self): + with self.create_pipeline() as p: + main = p | "main" >> beam.Create(["a", "b"]) + side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)]) + assert_that( + main + | beam.Map( + lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side) + ), + equal_to([("a", [1, 3]), ("b", [2])]), + ) + + @unittest.skip("Side inputs not yet supported") + def test_multimap_multiside_input(self): + # A test where two transforms in the same stage consume the same PCollection + # twice as side input. + with self.create_pipeline() as p: + main = p | "main" >> beam.Create(["a", "b"]) + side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)]) + assert_that( + main + | "first map" + >> beam.Map( + lambda k, d, l: (k, sorted(d[k]), sorted([e[1] for e in l])), + beam.pvalue.AsMultiMap(side), + beam.pvalue.AsList(side), + ) + | "second map" + >> beam.Map( + lambda k, d, l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), + beam.pvalue.AsMultiMap(side), + beam.pvalue.AsList(side), + ), + equal_to([("a", [1, 3], [1, 2, 3]), ("b", [2], [1, 2, 3])]), + ) + + @unittest.skip("Side inputs not yet supported") + def test_multimap_side_input_type_coercion(self): + with self.create_pipeline() as p: + main = p | "main" >> beam.Create(["a", "b"]) + # The type of this side-input is forced to Any (overriding type + # inference). Without type coercion to Tuple[Any, Any], the usage of this + # side-input in AsMultiMap() below should fail. + side = p | "side" >> beam.Create( + [("a", 1), ("b", 2), ("a", 3)] + ).with_output_types(typing.Any) + assert_that( + main + | beam.Map( + lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side) + ), + equal_to([("a", [1, 3]), ("b", [2])]), + ) + + @unittest.skip("Side inputs not yet supported") + def test_pardo_unfusable_side_inputs(self): + def cross_product(elem, sides): + for side in sides: + yield elem, side + + with self.create_pipeline() as p: + pcoll = p | beam.Create(["a", "b"]) + assert_that( + pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)), + equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]), + ) + + with self.create_pipeline() as p: + pcoll = p | beam.Create(["a", "b"]) + derived = ( + (pcoll,) + | beam.Flatten() + | beam.Map(lambda x: (x, x)) + | beam.GroupByKey() + | "Unkey" >> beam.Map(lambda kv: kv[0]) + ) + assert_that( + pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(derived)), + equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]), + ) + + @unittest.skip("State not yet supported") + def test_pardo_state_only(self): + index_state_spec = userstate.CombiningValueStateSpec("index", sum) + value_and_index_state_spec = userstate.ReadModifyWriteStateSpec( + "value:index", StrUtf8Coder() + ) + + class AddIndex(beam.DoFn): + def process( + self, + kv, + index=beam.DoFn.StateParam(index_state_spec), + value_and_index=beam.DoFn.StateParam(value_and_index_state_spec), + ): + k, v = kv + index.add(1) + value_and_index.write("%s:%s" % (v, index.read())) + yield k, v, index.read(), value_and_index.read() + + inputs = [("A", "a")] * 2 + [("B", "b")] * 3 + expected = [ + ("A", "a", 1, "a:1"), + ("A", "a", 2, "a:2"), + ("B", "b", 1, "b:1"), + ("B", "b", 2, "b:2"), + ("B", "b", 3, "b:3"), + ] + + with self.create_pipeline() as p: + assert_that( + p | beam.Create(inputs) | beam.ParDo(AddIndex()), equal_to(expected) + ) + + @unittest.skip("TestStream not yet supported") + def test_teststream_pardo_timers(self): + timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK) + + class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): + unused_key, ts = element + timer.set(ts) + timer.set(2 * ts) + + @userstate.on_timer(timer_spec) + def process_timer(self): + yield "fired" + + ts = ( + TestStream() + .add_elements([("k1", 10)]) # Set timer for 20 + .advance_watermark_to(100) + .add_elements([("k2", 100)]) # Set timer for 200 + .advance_watermark_to(1000) + ) + + with self.create_pipeline() as p: + _ = ( + p + | ts + | beam.ParDo(TimerDoFn()) + | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)) + ) + + # expected = [('fired', ts) for ts in (20, 200)] + # assert_that(actual, equal_to(expected)) + + @unittest.skip("Timers not yet supported") + def test_pardo_timers(self): + timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK) + state_spec = userstate.CombiningValueStateSpec("num_called", sum) + + class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): + unused_key, ts = element + timer.set(ts) + timer.set(2 * ts) + + @userstate.on_timer(timer_spec) + def process_timer( + self, + ts=beam.DoFn.TimestampParam, + timer=beam.DoFn.TimerParam(timer_spec), + state=beam.DoFn.StateParam(state_spec), + ): + if state.read() == 0: + state.add(1) + timer.set(timestamp.Timestamp(micros=2 * ts.micros)) + yield "fired" + + with self.create_pipeline() as p: + actual = ( + p + | beam.Create([("k1", 10), ("k2", 100)]) + | beam.ParDo(TimerDoFn()) + | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)) + ) + + expected = [("fired", ts) for ts in (20, 200, 40, 400)] + assert_that(actual, equal_to(expected)) + + @unittest.skip("Timers not yet supported") + def test_pardo_timers_clear(self): + timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK) + clear_timer_spec = userstate.TimerSpec( + "clear_timer", userstate.TimeDomain.WATERMARK + ) + + class TimerDoFn(beam.DoFn): + def process( + self, + element, + timer=beam.DoFn.TimerParam(timer_spec), + clear_timer=beam.DoFn.TimerParam(clear_timer_spec), + ): + unused_key, ts = element + timer.set(ts) + timer.set(2 * ts) + clear_timer.set(ts) + clear_timer.clear() + + @userstate.on_timer(timer_spec) + def process_timer(self): + yield "fired" + + @userstate.on_timer(clear_timer_spec) + def process_clear_timer(self): + yield "should not fire" + + with self.create_pipeline() as p: + actual = ( + p + | beam.Create([("k1", 10), ("k2", 100)]) + | beam.ParDo(TimerDoFn()) + | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)) + ) + + expected = [("fired", ts) for ts in (20, 200)] + assert_that(actual, equal_to(expected)) + + @unittest.skip("Timers not yet supported") + def test_pardo_state_timers(self): + self._run_pardo_state_timers(windowed=False) + + @unittest.skip("Timers not yet supported") + def test_pardo_state_timers_non_standard_coder(self): + self._run_pardo_state_timers(windowed=False, key_type=Any) + + @unittest.skip("Timers not yet supported") + def test_windowed_pardo_state_timers(self): + self._run_pardo_state_timers(windowed=True) + + def _run_pardo_state_timers(self, windowed, key_type=None): + """ + :param windowed: If True, uses an interval window, otherwise a global window + :param key_type: Allows to override the inferred key type. This is useful to + test the use of non-standard coders, e.g. Python's FastPrimitivesCoder. + """ + state_spec = userstate.BagStateSpec("state", beam.coders.StrUtf8Coder()) + timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK) + elements = list("abcdefgh") + key = "key" + buffer_size = 3 + + class BufferDoFn(beam.DoFn): + def process( + self, + kv, + ts=beam.DoFn.TimestampParam, + timer=beam.DoFn.TimerParam(timer_spec), + state=beam.DoFn.StateParam(state_spec), + ): + _, element = kv + state.add(element) + buffer = state.read() + # For real use, we'd keep track of this size separately. + if len(list(buffer)) >= 3: + state.clear() + yield buffer + else: + timer.set(ts + 1) + + @userstate.on_timer(timer_spec) + def process_timer(self, state=beam.DoFn.StateParam(state_spec)): + buffer = state.read() + state.clear() + yield buffer + + def is_buffered_correctly(actual): + # Pickling self in the closure for asserts gives errors (only on jenkins). + self = RayFnApiRunnerTest("__init__") + # Acutal should be a grouping of the inputs into batches of size + # at most buffer_size, but the actual batching is nondeterministic + # based on ordering and trigger firing timing. + self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements) + self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size) + if windowed: + # Elements were assigned to windows based on their parity. + # Assert that each grouping consists of elements belonging to the + # same window to ensure states and timers were properly partitioned. + for b in actual: + parity = set(ord(e) % 2 for e in b) + self.assertEqual(1, len(parity), b) + + with self.create_pipeline() as p: + actual = ( + p + | beam.Create(elements) + # Send even and odd elements to different windows. + | beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2)) + | beam.WindowInto( + window.FixedWindows(1) if windowed else window.GlobalWindows() + ) + | beam.Map(lambda x: (key, x)).with_output_types( + Tuple[key_type if key_type else type(key), Any] + ) + | beam.ParDo(BufferDoFn()) + ) + + assert_that(actual, is_buffered_correctly) + + @unittest.skip("Timers not yet supported") + def test_pardo_dynamic_timer(self): + class DynamicTimerDoFn(beam.DoFn): + dynamic_timer_spec = userstate.TimerSpec( + "dynamic_timer", userstate.TimeDomain.WATERMARK + ) + + def process( + self, element, dynamic_timer=beam.DoFn.TimerParam(dynamic_timer_spec) + ): + dynamic_timer.set(element[1], dynamic_timer_tag=element[0]) + + @userstate.on_timer(dynamic_timer_spec) + def dynamic_timer_callback( + self, + tag=beam.DoFn.DynamicTimerTagParam, + timestamp=beam.DoFn.TimestampParam, + ): + yield (tag, timestamp) + + with self.create_pipeline() as p: + actual = ( + p + | beam.Create([("key1", 10), ("key2", 20), ("key3", 30)]) + | beam.ParDo(DynamicTimerDoFn()) + ) + assert_that(actual, equal_to([("key1", 10), ("key2", 20), ("key3", 30)])) + + def test_sdf(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(ExpandStringsProvider()), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield element[cur] + cur += 1 + + with self.create_pipeline() as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()) + assert_that(actual, equal_to(list("".join(data)))) + + def test_sdf_with_dofn_as_restriction_provider(self): + class ExpandingStringsDoFn(beam.DoFn, ExpandStringsProvider): + def process( + self, element, restriction_tracker=beam.DoFn.RestrictionParam() + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield element[cur] + cur += 1 + + with self.create_pipeline() as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()) + assert_that(actual, equal_to(list("".join(data)))) + + def test_sdf_with_check_done_failed(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(ExpandStringsProvider()), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield element[cur] + cur += 1 + return + + with self.assertRaises(Exception): + with self.create_pipeline() as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + _ = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()) + + @unittest.skip("Watermark tracking not yet supported not yet supported") + def test_sdf_with_watermark_tracking(self): + class ExpandingStringsDoFn(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(ExpandStringsProvider()), + watermark_estimator=beam.DoFn.WatermarkEstimatorParam( + ManualWatermarkEstimator.default_provider() + ), + ): + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + watermark_estimator.set_watermark(timestamp.Timestamp(cur)) + assert ( + watermark_estimator.current_watermark() + == timestamp.Timestamp(cur) + ) + yield element[cur] + if cur % 2 == 1: + restriction_tracker.defer_remainder( + timestamp.Duration(micros=5) + ) + return + cur += 1 + + with self.create_pipeline() as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()) + assert_that(actual, equal_to(list("".join(data)))) + + @unittest.skip("SDF not yet supported") + def test_sdf_with_dofn_as_watermark_estimator(self): + class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider): + def initial_estimator_state(self, element, restriction): + return None + + def create_watermark_estimator(self, state): + return beam.io.watermark_estimators.ManualWatermarkEstimator(state) + + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(ExpandStringsProvider()), + watermark_estimator=beam.DoFn.WatermarkEstimatorParam( + ManualWatermarkEstimator.default_provider() + ), + ): + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + watermark_estimator.set_watermark(timestamp.Timestamp(cur)) + assert ( + watermark_estimator.current_watermark() + == timestamp.Timestamp(cur) + ) + yield element[cur] + if cur % 2 == 1: + restriction_tracker.defer_remainder( + timestamp.Duration(micros=5) + ) + return + cur += 1 + + with self.create_pipeline() as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()) + assert_that(actual, equal_to(list("".join(data)))) + + def run_sdf_initiated_checkpointing(self, is_drain=False): + counter = beam.metrics.Metrics.counter("ns", "my_counter") + + class ExpandStringsDoFn(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(ExpandStringsProvider()), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + counter.inc() + yield element[cur] + if cur % 2 == 1: + restriction_tracker.defer_remainder() + return + cur += 1 + + with self.create_pipeline(is_drain=is_drain) as p: + data = ["abc", "defghijklmno", "pqrstuv", "wxyz"] + actual = p | beam.Create(data) | beam.ParDo(ExpandStringsDoFn()) + + assert_that(actual, equal_to(list("".join(data)))) + + return # Metrics not yet supported! + # TODO: Enable following code section + # if isinstance(p.runner, fn_api_runner.FnApiRunner): + # res = p.runner._latest_run_result + # counters = res.metrics().query( + # beam.metrics.MetricsFilter().with_name('my_counter'))['counters'] + # self.assertEqual(1, len(counters)) + # self.assertEqual(counters[0].committed, len(''.join(data))) + + def test_sdf_with_sdf_initiated_checkpointing(self): + self.run_sdf_initiated_checkpointing(is_drain=False) + + @unittest.skip("SDF not yet supported") + def test_draining_sdf_with_sdf_initiated_checkpointing(self): + self.run_sdf_initiated_checkpointing(is_drain=True) + + @unittest.skip("SDF not yet supported") + def test_sdf_default_truncate_when_bounded(self): + class SimleSDF(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam( + OffsetRangeProvider(use_bounded_offset_range=True) + ), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield cur + cur += 1 + + with self.create_pipeline(is_drain=True) as p: + actual = p | beam.Create([10]) | beam.ParDo(SimleSDF()) + assert_that(actual, equal_to(range(10))) + + @unittest.skip("SDF not yet supported") + def test_sdf_default_truncate_when_unbounded(self): + class SimleSDF(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam( + OffsetRangeProvider(use_bounded_offset_range=False) + ), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield cur + cur += 1 + + with self.create_pipeline(is_drain=True) as p: + actual = p | beam.Create([10]) | beam.ParDo(SimleSDF()) + assert_that(actual, equal_to([])) + + @unittest.skip("SDF not yet supported") + def test_sdf_with_truncate(self): + class SimleSDF(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam( + OffsetRangeProviderWithTruncate() + ), + ): + assert isinstance(restriction_tracker, RestrictionTrackerView) + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + yield cur + cur += 1 + + with self.create_pipeline(is_drain=True) as p: + actual = p | beam.Create([10]) | beam.ParDo(SimleSDF()) + assert_that(actual, equal_to(range(5))) + + def test_group_by_key(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create([("a", 1), ("a", 2), ("b", 3)]) + | beam.GroupByKey() + | beam.Map(lambda k_vs: (k_vs[0], sorted(k_vs[1]))) + ) + assert_that(res, equal_to([("a", [1, 2]), ("b", [3])])) + + # Runners may special case the Reshuffle transform urn. + def test_reshuffle(self): + with self.create_pipeline() as p: + assert_that( + p | beam.Create([1, 2, 3]) | beam.Reshuffle(), equal_to([1, 2, 3]) + ) + + def test_flatten(self, with_transcoding=True): + with self.create_pipeline() as p: + if with_transcoding: + # Additional element which does not match with the first type + additional = [ord("d")] + else: + additional = ["d"] + res = ( + p | "a" >> beam.Create(["a"]), + p | "bc" >> beam.Create(["b", "c"]), + p | "d" >> beam.Create(additional), + ) | beam.Flatten() + assert_that(res, equal_to(["a", "b", "c"] + additional)) + + def test_flatten_same_pcollections(self, with_transcoding=True): + with self.create_pipeline() as p: + pc = p | beam.Create(["a", "b"]) + assert_that((pc, pc, pc) | beam.Flatten(), equal_to(["a", "b"] * 3)) + + @unittest.skip("Combiner lifting not yet supported") + def test_combine_per_key(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create([("a", 1), ("a", 2), ("b", 3)]) + | beam.CombinePerKey(beam.combiners.MeanCombineFn()) + ) + assert_that(res, equal_to([("a", 1.5), ("b", 3.0)])) + + def test_read(self): + # Can't use NamedTemporaryFile as a context + # due to https://bugs.python.org/issue14243 + temp_file = tempfile.NamedTemporaryFile(delete=False) + try: + temp_file.write(b"a\nb\nc") + temp_file.close() + with self.create_pipeline() as p: + assert_that( + p | beam.io.ReadFromText(temp_file.name), equal_to(["a", "b", "c"]) + ) + finally: + os.unlink(temp_file.name) + + def test_windowing(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create([1, 2, 100, 101, 102]) + | beam.Map(lambda t: window.TimestampedValue(("k", t), t)) + | beam.WindowInto(beam.transforms.window.Sessions(10)) + | beam.GroupByKey() + | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))) + ) + assert_that(res, equal_to([("k", [1, 2]), ("k", [100, 101, 102])])) + + def test_custom_merging_window(self): + with self.create_pipeline() as p: + res = ( + p + | beam.Create([1, 2, 100, 101, 102]) + | beam.Map(lambda t: window.TimestampedValue(("k", t), t)) + | beam.WindowInto(CustomMergingWindowFn()) + | beam.GroupByKey() + | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))) + ) + assert_that(res, equal_to([("k", [1]), ("k", [101]), ("k", [2, 100, 102])])) + gc.collect() + from apache_beam.runners.portability.fn_api_runner.execution import ( + GenericMergingWindowFn, + ) + + self.assertEqual(GenericMergingWindowFn._HANDLES, {}) + + @unittest.skip("BEAM-9119: test is flaky") + def test_large_elements(self): + with self.create_pipeline() as p: + big = ( + p + | beam.Create(["a", "a", "b"]) + | beam.Map(lambda x: (x, x * data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD)) + ) + + side_input_res = big | beam.Map( + lambda x, side: (x[0], side.count(x[0])), + beam.pvalue.AsList(big | beam.Map(lambda x: x[0])), + ) + assert_that( + side_input_res, equal_to([("a", 2), ("a", 2), ("b", 1)]), label="side" + ) + + gbk_res = big | beam.GroupByKey() | beam.Map(lambda x: x[0]) + assert_that(gbk_res, equal_to(["a", "b"]), label="gbk") + + @unittest.skip("Error messages need to improve") + def test_error_message_includes_stage(self): + with self.assertRaises(BaseException) as e_cm: + with self.create_pipeline() as p: + + def raise_error(x): + raise RuntimeError("x") + + # pylint: disable=expression-not-assigned + ( + p + | beam.Create(["a", "b"]) + | "StageA" >> beam.Map(lambda x: x) + | "StageB" >> beam.Map(lambda x: x) + | "StageC" >> beam.Map(raise_error) + | "StageD" >> beam.Map(lambda x: x) + ) + message = e_cm.exception.args[0] + self.assertIn("StageC", message) + self.assertNotIn("StageB", message) + + def test_error_traceback_includes_user_code(self): + def first(x): + return second(x) + + def second(x): + return third(x) + + def third(x): + raise ValueError("x") + + try: + with self.create_pipeline() as p: + p | beam.Create([0]) | beam.Map( + first + ) # pylint: disable=expression-not-assigned + except Exception: # pylint: disable=broad-except + message = traceback.format_exc() else: - timer.set(ts + 1) - - @userstate.on_timer(timer_spec) - def process_timer(self, state=beam.DoFn.StateParam(state_spec)): - buffer = state.read() - state.clear() - yield buffer - - def is_buffered_correctly(actual): - # Pickling self in the closure for asserts gives errors (only on jenkins). - self = FnApiRunnerTest('__init__') - # Acutal should be a grouping of the inputs into batches of size - # at most buffer_size, but the actual batching is nondeterministic - # based on ordering and trigger firing timing. - self.assertEqual(sorted(sum((list(b) for b in actual), [])), elements) - self.assertEqual(max(len(list(buffer)) for buffer in actual), buffer_size) - if windowed: - # Elements were assigned to windows based on their parity. - # Assert that each grouping consists of elements belonging to the - # same window to ensure states and timers were properly partitioned. - for b in actual: - parity = set(ord(e) % 2 for e in b) - self.assertEqual(1, len(parity), b) - - with self.create_pipeline() as p: - actual = ( - p - | beam.Create(elements) - # Send even and odd elements to different windows. - | beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2)) - | beam.WindowInto( - window.FixedWindows(1) if windowed else window.GlobalWindows()) - | beam.Map(lambda x: (key, x)).with_output_types( - Tuple[key_type if key_type else type(key), Any]) - | beam.ParDo(BufferDoFn())) - - assert_that(actual, is_buffered_correctly) - - @unittest.skip('Timers not yet supported') - def test_pardo_dynamic_timer(self): - class DynamicTimerDoFn(beam.DoFn): - dynamic_timer_spec = userstate.TimerSpec( - 'dynamic_timer', userstate.TimeDomain.WATERMARK) - - def process( - self, element, - dynamic_timer=beam.DoFn.TimerParam(dynamic_timer_spec)): - dynamic_timer.set(element[1], dynamic_timer_tag=element[0]) - - @userstate.on_timer(dynamic_timer_spec) - def dynamic_timer_callback( - self, - tag=beam.DoFn.DynamicTimerTagParam, - timestamp=beam.DoFn.TimestampParam): - yield (tag, timestamp) - - with self.create_pipeline() as p: - actual = ( - p - | beam.Create([('key1', 10), ('key2', 20), ('key3', 30)]) - | beam.ParDo(DynamicTimerDoFn())) - assert_that(actual, equal_to([('key1', 10), ('key2', 20), ('key3', 30)])) - - def test_sdf(self): - class ExpandingStringsDoFn(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - ExpandStringsProvider())): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield element[cur] - cur += 1 - - with self.create_pipeline() as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) - assert_that(actual, equal_to(list(''.join(data)))) - - def test_sdf_with_dofn_as_restriction_provider(self): - class ExpandingStringsDoFn(beam.DoFn, ExpandStringsProvider): - def process( - self, element, restriction_tracker=beam.DoFn.RestrictionParam()): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield element[cur] - cur += 1 - - with self.create_pipeline() as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) - assert_that(actual, equal_to(list(''.join(data)))) - - def test_sdf_with_check_done_failed(self): - class ExpandingStringsDoFn(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - ExpandStringsProvider())): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield element[cur] - cur += 1 - return - - with self.assertRaises(Exception): - with self.create_pipeline() as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - _ = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) - - @unittest.skip('Watermark tracking not yet supported not yet supported') - def test_sdf_with_watermark_tracking(self): - class ExpandingStringsDoFn(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - ExpandStringsProvider()), - watermark_estimator=beam.DoFn.WatermarkEstimatorParam( - ManualWatermarkEstimator.default_provider())): - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - watermark_estimator.set_watermark(timestamp.Timestamp(cur)) - assert ( - watermark_estimator.current_watermark() == timestamp.Timestamp( - cur)) - yield element[cur] - if cur % 2 == 1: - restriction_tracker.defer_remainder(timestamp.Duration(micros=5)) - return - cur += 1 - - with self.create_pipeline() as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) - assert_that(actual, equal_to(list(''.join(data)))) - - @unittest.skip('SDF not yet supported') - def test_sdf_with_dofn_as_watermark_estimator(self): - class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider): - def initial_estimator_state(self, element, restriction): - return None + raise AssertionError("expected exception not raised") - def create_watermark_estimator(self, state): - return beam.io.watermark_estimators.ManualWatermarkEstimator(state) - - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - ExpandStringsProvider()), - watermark_estimator=beam.DoFn.WatermarkEstimatorParam( - ManualWatermarkEstimator.default_provider())): - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - watermark_estimator.set_watermark(timestamp.Timestamp(cur)) - assert ( - watermark_estimator.current_watermark() == timestamp.Timestamp( - cur)) - yield element[cur] - if cur % 2 == 1: - restriction_tracker.defer_remainder(timestamp.Duration(micros=5)) - return - cur += 1 - - with self.create_pipeline() as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())) - assert_that(actual, equal_to(list(''.join(data)))) - - def run_sdf_initiated_checkpointing(self, is_drain=False): - counter = beam.metrics.Metrics.counter('ns', 'my_counter') - - class ExpandStringsDoFn(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - ExpandStringsProvider())): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - counter.inc() - yield element[cur] - if cur % 2 == 1: - restriction_tracker.defer_remainder() - return - cur += 1 - - with self.create_pipeline(is_drain=is_drain) as p: - data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] - actual = (p | beam.Create(data) | beam.ParDo(ExpandStringsDoFn())) - - assert_that(actual, equal_to(list(''.join(data)))) - - return # Metrics not yet supported! - # TODO: Enable following code section - # if isinstance(p.runner, fn_api_runner.FnApiRunner): - # res = p.runner._latest_run_result - # counters = res.metrics().query( - # beam.metrics.MetricsFilter().with_name('my_counter'))['counters'] - # self.assertEqual(1, len(counters)) - # self.assertEqual(counters[0].committed, len(''.join(data))) - - def test_sdf_with_sdf_initiated_checkpointing(self): - self.run_sdf_initiated_checkpointing(is_drain=False) - - @unittest.skip('SDF not yet supported') - def test_draining_sdf_with_sdf_initiated_checkpointing(self): - self.run_sdf_initiated_checkpointing(is_drain=True) - - @unittest.skip('SDF not yet supported') - def test_sdf_default_truncate_when_bounded(self): - class SimleSDF(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - OffsetRangeProvider(use_bounded_offset_range=True))): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield cur - cur += 1 - - with self.create_pipeline(is_drain=True) as p: - actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF())) - assert_that(actual, equal_to(range(10))) - - @unittest.skip('SDF not yet supported') - def test_sdf_default_truncate_when_unbounded(self): - class SimleSDF(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - OffsetRangeProvider(use_bounded_offset_range=False))): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield cur - cur += 1 - - with self.create_pipeline(is_drain=True) as p: - actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF())) - assert_that(actual, equal_to([])) - - @unittest.skip('SDF not yet supported') - def test_sdf_with_truncate(self): - class SimleSDF(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam( - OffsetRangeProviderWithTruncate())): - assert isinstance(restriction_tracker, RestrictionTrackerView) - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - yield cur - cur += 1 - - with self.create_pipeline(is_drain=True) as p: - actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF())) - assert_that(actual, equal_to(range(5))) - - def test_group_by_key(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create([('a', 1), ('a', 2), ('b', 3)]) - | beam.GroupByKey() - | beam.Map(lambda k_vs: (k_vs[0], sorted(k_vs[1])))) - assert_that(res, equal_to([('a', [1, 2]), ('b', [3])])) - - # Runners may special case the Reshuffle transform urn. - def test_reshuffle(self): - with self.create_pipeline() as p: - assert_that( - p | beam.Create([1, 2, 3]) | beam.Reshuffle(), equal_to([1, 2, 3])) - - def test_flatten(self, with_transcoding=True): - with self.create_pipeline() as p: - if with_transcoding: - # Additional element which does not match with the first type - additional = [ord('d')] - else: - additional = ['d'] - res = ( - p | 'a' >> beam.Create(['a']), - p | 'bc' >> beam.Create(['b', 'c']), - p | 'd' >> beam.Create(additional)) | beam.Flatten() - assert_that(res, equal_to(['a', 'b', 'c'] + additional)) - - def test_flatten_same_pcollections(self, with_transcoding=True): - with self.create_pipeline() as p: - pc = p | beam.Create(['a', 'b']) - assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3)) - - @unittest.skip('Combiner lifting not yet supported') - def test_combine_per_key(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create([('a', 1), ('a', 2), ('b', 3)]) - | beam.CombinePerKey(beam.combiners.MeanCombineFn())) - assert_that(res, equal_to([('a', 1.5), ('b', 3.0)])) - - def test_read(self): - # Can't use NamedTemporaryFile as a context - # due to https://bugs.python.org/issue14243 - temp_file = tempfile.NamedTemporaryFile(delete=False) - try: - temp_file.write(b'a\nb\nc') - temp_file.close() - with self.create_pipeline() as p: - assert_that( - p | beam.io.ReadFromText(temp_file.name), equal_to(['a', 'b', 'c'])) - finally: - os.unlink(temp_file.name) - - def test_windowing(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create([1, 2, 100, 101, 102]) - | beam.Map(lambda t: window.TimestampedValue(('k', t), t)) - | beam.WindowInto(beam.transforms.window.Sessions(10)) - | beam.GroupByKey() - | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1])))) - assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])])) - - def test_custom_merging_window(self): - with self.create_pipeline() as p: - res = ( - p - | beam.Create([1, 2, 100, 101, 102]) - | beam.Map(lambda t: window.TimestampedValue(('k', t), t)) - | beam.WindowInto(CustomMergingWindowFn()) - | beam.GroupByKey() - | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1])))) - assert_that( - res, equal_to([('k', [1]), ('k', [101]), ('k', [2, 100, 102])])) - gc.collect() - from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn - self.assertEqual(GenericMergingWindowFn._HANDLES, {}) - - @unittest.skip('BEAM-9119: test is flaky') - def test_large_elements(self): - with self.create_pipeline() as p: - big = ( - p - | beam.Create(['a', 'a', 'b']) - | - beam.Map(lambda x: (x, x * data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD))) - - side_input_res = ( - big - | beam.Map( - lambda x, - side: (x[0], side.count(x[0])), - beam.pvalue.AsList(big | beam.Map(lambda x: x[0])))) - assert_that( - side_input_res, - equal_to([('a', 2), ('a', 2), ('b', 1)]), - label='side') - - gbk_res = (big | beam.GroupByKey() | beam.Map(lambda x: x[0])) - assert_that(gbk_res, equal_to(['a', 'b']), label='gbk') - - @unittest.skip('Error messages need to improve') - def test_error_message_includes_stage(self): - with self.assertRaises(BaseException) as e_cm: - with self.create_pipeline() as p: - - def raise_error(x): - raise RuntimeError('x') + self.assertIn("first", message) + self.assertIn("second", message) + self.assertIn("third", message) - # pylint: disable=expression-not-assigned - ( - p - | beam.Create(['a', 'b']) - | 'StageA' >> beam.Map(lambda x: x) - | 'StageB' >> beam.Map(lambda x: x) - | 'StageC' >> beam.Map(raise_error) - | 'StageD' >> beam.Map(lambda x: x)) - message = e_cm.exception.args[0] - self.assertIn('StageC', message) - self.assertNotIn('StageB', message) - - def test_error_traceback_includes_user_code(self): - def first(x): - return second(x) - - def second(x): - return third(x) - - def third(x): - raise ValueError('x') - - try: - with self.create_pipeline() as p: - p | beam.Create([0]) | beam.Map(first) # pylint: disable=expression-not-assigned - except Exception: # pylint: disable=broad-except - message = traceback.format_exc() - else: - raise AssertionError('expected exception not raised') - - self.assertIn('first', message) - self.assertIn('second', message) - self.assertIn('third', message) - - def test_no_subtransform_composite(self): - class First(beam.PTransform): - def expand(self, pcolls): - return pcolls[0] - - with self.create_pipeline() as p: - pcoll_a = p | 'a' >> beam.Create(['a']) - pcoll_b = p | 'b' >> beam.Create(['b']) - assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a'])) - - @unittest.skip('Metrics not yet supported') - def test_metrics(self, check_gauge=True): - p = self.create_pipeline() - - counter = beam.metrics.Metrics.counter('ns', 'counter') - distribution = beam.metrics.Metrics.distribution('ns', 'distribution') - gauge = beam.metrics.Metrics.gauge('ns', 'gauge') - - pcoll = p | beam.Create(['a', 'zzz']) - # pylint: disable=expression-not-assigned - pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc()) - pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x))) - pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x))) - pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(3)) - - res = p.run() - res.wait_until_finish() - - t1, t2 = res.metrics().query(beam.metrics.MetricsFilter() - .with_name('counter'))['counters'] - self.assertEqual(t1.committed + t2.committed, 6) - - dist, = res.metrics().query(beam.metrics.MetricsFilter() - .with_name('distribution'))['distributions'] - self.assertEqual( - dist.committed.data, beam.metrics.cells.DistributionData(4, 2, 1, 3)) - self.assertEqual(dist.committed.mean, 2.0) - - if check_gauge: - gaug, = res.metrics().query(beam.metrics.MetricsFilter() - .with_name('gauge'))['gauges'] - self.assertEqual(gaug.committed.value, 3) - - def test_callbacks_with_exception(self): - elements_list = ['1', '2'] - - def raise_expetion(): - raise Exception('raise exception when calling callback') - - class FinalizebleDoFnWithException(beam.DoFn): - def process( - self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam): - bundle_finalizer.register(raise_expetion) - yield element - - with self.create_pipeline() as p: - res = ( - p - | beam.Create(elements_list) - | beam.ParDo(FinalizebleDoFnWithException())) - assert_that(res, equal_to(['1', '2'])) - - @unittest.skip('SDF not yet supported') - def test_register_finalizations(self): - event_recorder = EventRecorder(tempfile.gettempdir()) - - class FinalizableSplittableDoFn(beam.DoFn): - def process( - self, - element, - bundle_finalizer=beam.DoFn.BundleFinalizerParam, - restriction_tracker=beam.DoFn.RestrictionParam( - OffsetRangeProvider( - use_bounded_offset_range=True, checkpoint_only=True))): - # We use SDF to enforce finalization call happens by using - # self-initiated checkpoint. - if 'finalized' in event_recorder.events(): - restriction_tracker.try_claim( - restriction_tracker.current_restriction().start) - yield element - restriction_tracker.try_claim(element) - return - if restriction_tracker.try_claim( - restriction_tracker.current_restriction().start): - bundle_finalizer.register(lambda: event_recorder.record('finalized')) - # We sleep here instead of setting a resume time since the resume time - # doesn't need to be honored. - time.sleep(1) - restriction_tracker.defer_remainder() - - with self.create_pipeline() as p: - max_retries = 100 - res = ( - p - | beam.Create([max_retries]) - | beam.ParDo(FinalizableSplittableDoFn())) - assert_that(res, equal_to([max_retries])) - - event_recorder.cleanup() - - @unittest.skip('Combiners not yet supported') - def test_sdf_synthetic_source(self): - common_attrs = { - 'key_size': 1, - 'value_size': 1, - 'initial_splitting_num_bundles': 2, - 'initial_splitting_desired_bundle_size': 2, - 'sleep_per_input_record_sec': 0, - 'initial_splitting': 'const' - } - num_source_description = 5 - min_num_record = 10 - max_num_record = 20 - - # pylint: disable=unused-variable - source_descriptions = ([ - dict({'num_records': random.randint(min_num_record, max_num_record)}, - **common_attrs) for i in range(0, num_source_description) - ]) - total_num_records = 0 - for source in source_descriptions: - total_num_records += source['num_records'] - - with self.create_pipeline() as p: - res = ( - p - | beam.Create(source_descriptions) - | beam.ParDo(SyntheticSDFAsSource()) - | beam.combiners.Count.Globally()) - assert_that(res, equal_to([total_num_records])) - - def test_create_value_provider_pipeline_option(self): - # Verify that the runner can execute a pipeline when there are value - # provider pipeline options - # pylint: disable=unused-variable - class FooOptions(PipelineOptions): - @classmethod - def _add_argparse_args(cls, parser): - parser.add_value_provider_argument( - "--foo", help='a value provider argument', default="bar") - - RuntimeValueProvider.set_runtime_options({}) - - with self.create_pipeline() as p: - assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b'])) - - def _test_pack_combiners(self, assert_using_counter_names): - counter = beam.metrics.Metrics.counter('ns', 'num_values') - - def min_with_counter(values): - counter.inc() - return min(values) - - def max_with_counter(values): - counter.inc() - return max(values) - - class PackableCombines(beam.PTransform): - def annotations(self): - return {python_urns.APPLY_COMBINER_PACKING: b''} - - def expand(self, pcoll): - assert_that( - pcoll | 'PackableMin' >> beam.CombineGlobally(min_with_counter), - equal_to([10]), - label='AssertMin') - assert_that( - pcoll | 'PackableMax' >> beam.CombineGlobally(max_with_counter), - equal_to([30]), - label='AssertMax') - - with self.create_pipeline() as p: - _ = p | beam.Create([10, 20, 30]) | PackableCombines() - - res = p.run() - res.wait_until_finish() - - packed_step_name_regex = ( - r'.*Packed.*PackableMin.*CombinePerKey.*PackableMax.*CombinePerKey.*' + - 'Pack.*') - - counters = res.metrics().query(beam.metrics.MetricsFilter())['counters'] - step_names = set(m.key.step for m in counters if m.key.step) - pipeline_options = p._options - if assert_using_counter_names: - if pipeline_options.view_as(StandardOptions).streaming: - self.assertFalse( - any(re.match(packed_step_name_regex, s) for s in step_names)) - else: - self.assertTrue( - any(re.match(packed_step_name_regex, s) for s in step_names)) + def test_no_subtransform_composite(self): + class First(beam.PTransform): + def expand(self, pcolls): + return pcolls[0] + + with self.create_pipeline() as p: + pcoll_a = p | "a" >> beam.Create(["a"]) + pcoll_b = p | "b" >> beam.Create(["b"]) + assert_that((pcoll_a, pcoll_b) | First(), equal_to(["a"])) - @unittest.skip('Combiners not yet supported') - def test_pack_combiners(self): - self._test_pack_combiners(assert_using_counter_names=True) + @unittest.skip("Metrics not yet supported") + def test_metrics(self, check_gauge=True): + p = self.create_pipeline() + + counter = beam.metrics.Metrics.counter("ns", "counter") + distribution = beam.metrics.Metrics.distribution("ns", "distribution") + gauge = beam.metrics.Metrics.gauge("ns", "gauge") + + pcoll = p | beam.Create(["a", "zzz"]) + # pylint: disable=expression-not-assigned + pcoll | "count1" >> beam.FlatMap(lambda x: counter.inc()) + pcoll | "count2" >> beam.FlatMap(lambda x: counter.inc(len(x))) + pcoll | "dist" >> beam.FlatMap(lambda x: distribution.update(len(x))) + pcoll | "gauge" >> beam.FlatMap(lambda x: gauge.set(3)) + + res = p.run() + res.wait_until_finish() + + t1, t2 = res.metrics().query(beam.metrics.MetricsFilter().with_name("counter"))[ + "counters" + ] + self.assertEqual(t1.committed + t2.committed, 6) + + (dist,) = res.metrics().query( + beam.metrics.MetricsFilter().with_name("distribution") + )["distributions"] + self.assertEqual( + dist.committed.data, beam.metrics.cells.DistributionData(4, 2, 1, 3) + ) + self.assertEqual(dist.committed.mean, 2.0) + + if check_gauge: + (gaug,) = res.metrics().query( + beam.metrics.MetricsFilter().with_name("gauge") + )["gauges"] + self.assertEqual(gaug.committed.value, 3) + + def test_callbacks_with_exception(self): + elements_list = ["1", "2"] + + def raise_expetion(): + raise Exception("raise exception when calling callback") + + class FinalizebleDoFnWithException(beam.DoFn): + def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam): + bundle_finalizer.register(raise_expetion) + yield element + + with self.create_pipeline() as p: + res = ( + p + | beam.Create(elements_list) + | beam.ParDo(FinalizebleDoFnWithException()) + ) + assert_that(res, equal_to(["1", "2"])) + + @unittest.skip("SDF not yet supported") + def test_register_finalizations(self): + event_recorder = EventRecorder(tempfile.gettempdir()) + + class FinalizableSplittableDoFn(beam.DoFn): + def process( + self, + element, + bundle_finalizer=beam.DoFn.BundleFinalizerParam, + restriction_tracker=beam.DoFn.RestrictionParam( + OffsetRangeProvider( + use_bounded_offset_range=True, checkpoint_only=True + ) + ), + ): + # We use SDF to enforce finalization call happens by using + # self-initiated checkpoint. + if "finalized" in event_recorder.events(): + restriction_tracker.try_claim( + restriction_tracker.current_restriction().start + ) + yield element + restriction_tracker.try_claim(element) + return + if restriction_tracker.try_claim( + restriction_tracker.current_restriction().start + ): + bundle_finalizer.register( + lambda: event_recorder.record("finalized") + ) + # We sleep here instead of setting a resume time since + # the resume time doesn't need to be honored. + time.sleep(1) + restriction_tracker.defer_remainder() + + with self.create_pipeline() as p: + max_retries = 100 + res = ( + p | beam.Create([max_retries]) | beam.ParDo(FinalizableSplittableDoFn()) + ) + assert_that(res, equal_to([max_retries])) + + event_recorder.cleanup() + + @unittest.skip("Combiners not yet supported") + def test_sdf_synthetic_source(self): + common_attrs = { + "key_size": 1, + "value_size": 1, + "initial_splitting_num_bundles": 2, + "initial_splitting_desired_bundle_size": 2, + "sleep_per_input_record_sec": 0, + "initial_splitting": "const", + } + num_source_description = 5 + min_num_record = 10 + max_num_record = 20 + + # pylint: disable=unused-variable + source_descriptions = [ + dict( + {"num_records": random.randint(min_num_record, max_num_record)}, + **common_attrs, + ) + for i in range(0, num_source_description) + ] + total_num_records = 0 + for source in source_descriptions: + total_num_records += source["num_records"] + + with self.create_pipeline() as p: + res = ( + p + | beam.Create(source_descriptions) + | beam.ParDo(SyntheticSDFAsSource()) + | beam.combiners.Count.Globally() + ) + assert_that(res, equal_to([total_num_records])) + + def test_create_value_provider_pipeline_option(self): + # Verify that the runner can execute a pipeline when there are value + # provider pipeline options + # pylint: disable=unused-variable + class FooOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument( + "--foo", help="a value provider argument", default="bar" + ) + + RuntimeValueProvider.set_runtime_options({}) + + with self.create_pipeline() as p: + assert_that(p | beam.Create(["a", "b"]), equal_to(["a", "b"])) + + def _test_pack_combiners(self, assert_using_counter_names): + counter = beam.metrics.Metrics.counter("ns", "num_values") + + def min_with_counter(values): + counter.inc() + return min(values) + + def max_with_counter(values): + counter.inc() + return max(values) + + class PackableCombines(beam.PTransform): + def annotations(self): + return {python_urns.APPLY_COMBINER_PACKING: b""} + + def expand(self, pcoll): + assert_that( + pcoll | "PackableMin" >> beam.CombineGlobally(min_with_counter), + equal_to([10]), + label="AssertMin", + ) + assert_that( + pcoll | "PackableMax" >> beam.CombineGlobally(max_with_counter), + equal_to([30]), + label="AssertMax", + ) + + with self.create_pipeline() as p: + _ = p | beam.Create([10, 20, 30]) | PackableCombines() + + res = p.run() + res.wait_until_finish() + + packed_step_name_regex = ( + r".*Packed.*PackableMin.*CombinePerKey.*PackableMax.*CombinePerKey.*" + + "Pack.*" + ) + + counters = res.metrics().query(beam.metrics.MetricsFilter())["counters"] + step_names = set(m.key.step for m in counters if m.key.step) + pipeline_options = p._options + if assert_using_counter_names: + if pipeline_options.view_as(StandardOptions).streaming: + self.assertFalse( + any(re.match(packed_step_name_regex, s) for s in step_names) + ) + else: + self.assertTrue( + any(re.match(packed_step_name_regex, s) for s in step_names) + ) + + @unittest.skip("Combiners not yet supported") + def test_pack_combiners(self): + self._test_pack_combiners(assert_using_counter_names=True) # These tests are kept in a separate group so that they are @@ -1114,824 +1215,873 @@ def test_pack_combiners(self): # it makes the probability of sampling far too small # upon repeating bundle processing due to unncessarily incrementing # the sampling counter. -@unittest.skip('Metrics not yet supported.') +@unittest.skip("Metrics not yet supported.") class RayRunnerMetricsTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - if not ray.is_initialized(): - ray.init(local_mode=True) - - def assert_has_counter( - self, mon_infos, urn, labels, value=None, ge_value=None): - found = 0 - matches = [] - for mi in mon_infos: - if has_urn_and_labels(mi, urn, labels): - extracted_value = monitoring_infos.extract_counter_value(mi) - if ge_value is not None: - if extracted_value >= ge_value: - found = found + 1 - elif value is not None: - if extracted_value == value: - found = found + 1 - else: - found = found + 1 - ge_value_str = {'ge_value': ge_value} if ge_value else '' - value_str = {'value': value} if value else '' - self.assertEqual( - 1, - found, - "Found (%s, %s) Expected only 1 monitoring_info for %s." % ( + @classmethod + def setUpClass(cls) -> None: + if not ray.is_initialized(): + ray.init(local_mode=True) + + def assert_has_counter(self, mon_infos, urn, labels, value=None, ge_value=None): + found = 0 + matches = [] + for mi in mon_infos: + if has_urn_and_labels(mi, urn, labels): + extracted_value = monitoring_infos.extract_counter_value(mi) + if ge_value is not None: + if extracted_value >= ge_value: + found = found + 1 + elif value is not None: + if extracted_value == value: + found = found + 1 + else: + found = found + 1 + ge_value_str = {"ge_value": ge_value} if ge_value else "" + value_str = {"value": value} if value else "" + self.assertEqual( + 1, found, - matches, - (urn, labels, value_str, ge_value_str), - )) - - def assert_has_distribution( - self, mon_infos, urn, labels, sum=None, count=None, min=None, max=None): - # TODO(ajamato): Consider adding a matcher framework - sum = _matcher_or_equal_to(sum) - count = _matcher_or_equal_to(count) - min = _matcher_or_equal_to(min) - max = _matcher_or_equal_to(max) - found = 0 - description = StringDescription() - for mi in mon_infos: - if has_urn_and_labels(mi, urn, labels): - (extracted_count, extracted_sum, extracted_min, - extracted_max) = monitoring_infos.extract_distribution(mi) - increment = 1 - if sum is not None: - description.append_text(' sum: ') - sum.describe_to(description) - if not sum.matches(extracted_sum): - increment = 0 - if count is not None: - description.append_text(' count: ') - count.describe_to(description) - if not count.matches(extracted_count): - increment = 0 - if min is not None: - description.append_text(' min: ') - min.describe_to(description) - if not min.matches(extracted_min): - increment = 0 - if max is not None: - description.append_text(' max: ') - max.describe_to(description) - if not max.matches(extracted_max): - increment = 0 - found += increment - self.assertEqual( - 1, - found, - "Found (%s) Expected only 1 monitoring_info for %s." % ( + "Found (%s, %s) Expected only 1 monitoring_info for %s." + % ( + found, + matches, + (urn, labels, value_str, ge_value_str), + ), + ) + + def assert_has_distribution( + self, mon_infos, urn, labels, sum=None, count=None, min=None, max=None + ): + # TODO(ajamato): Consider adding a matcher framework + sum = _matcher_or_equal_to(sum) + count = _matcher_or_equal_to(count) + min = _matcher_or_equal_to(min) + max = _matcher_or_equal_to(max) + found = 0 + description = StringDescription() + for mi in mon_infos: + if has_urn_and_labels(mi, urn, labels): + ( + extracted_count, + extracted_sum, + extracted_min, + extracted_max, + ) = monitoring_infos.extract_distribution(mi) + increment = 1 + if sum is not None: + description.append_text(" sum: ") + sum.describe_to(description) + if not sum.matches(extracted_sum): + increment = 0 + if count is not None: + description.append_text(" count: ") + count.describe_to(description) + if not count.matches(extracted_count): + increment = 0 + if min is not None: + description.append_text(" min: ") + min.describe_to(description) + if not min.matches(extracted_min): + increment = 0 + if max is not None: + description.append_text(" max: ") + max.describe_to(description) + if not max.matches(extracted_max): + increment = 0 + found += increment + self.assertEqual( + 1, found, - (urn, labels, str(description)), - )) - - def create_pipeline(self): - return beam.Pipeline(runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner()) - - def test_element_count_metrics(self): - class GenerateTwoOutputs(beam.DoFn): - def process(self, element): - yield str(element) + '1' - yield beam.pvalue.TaggedOutput('SecondOutput', str(element) + '2') - yield beam.pvalue.TaggedOutput('SecondOutput', str(element) + '2') - yield beam.pvalue.TaggedOutput('ThirdOutput', str(element) + '3') - - class PassThrough(beam.DoFn): - def process(self, element): - yield element - - p = self.create_pipeline() - - # Produce enough elements to make sure byte sampling occurs. - num_source_elems = 100 - pcoll = p | beam.Create(['a%d' % i for i in range(num_source_elems)], - reshuffle=False) - - # pylint: disable=expression-not-assigned - pardo = ( - 'StepThatDoesTwoOutputs' >> beam.ParDo( - GenerateTwoOutputs()).with_outputs( - 'SecondOutput', 'ThirdOutput', main='FirstAndMainOutput')) - - # Actually feed pcollection to pardo - second_output, third_output, first_output = (pcoll | pardo) - - # consume some of elements - merged = ((first_output, second_output, third_output) | beam.Flatten()) - merged | ('PassThrough') >> beam.ParDo(PassThrough()) - second_output | ('PassThrough2') >> beam.ParDo(PassThrough()) - - res = p.run() - res.wait_until_finish() - - result_metrics = res.monitoring_metrics() - - counters = result_metrics.monitoring_infos() - # All element count and byte count metrics must have a PCOLLECTION_LABEL. - self.assertFalse([ - x for x in counters if x.urn in [ - monitoring_infos.ELEMENT_COUNT_URN, - monitoring_infos.SAMPLED_BYTE_SIZE_URN - ] and monitoring_infos.PCOLLECTION_LABEL not in x.labels - ]) - try: - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_1' - } - self.assert_has_counter( - counters, monitoring_infos.ELEMENT_COUNT_URN, labels, 1) - - # Create output. - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_3' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # GenerateTwoOutputs, main output. - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_4' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # GenerateTwoOutputs, "SecondOutput" output. - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_5' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - 2 * num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # GenerateTwoOutputs, "ThirdOutput" output. - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_6' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # Skipping other pcollections due to non-deterministic naming for multiple - # outputs. - # Flatten/Read, main output. - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_7' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - 4 * num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # PassThrough, main output - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_8' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - 4 * num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - - # PassThrough2, main output - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_9' - } - self.assert_has_counter( - counters, - monitoring_infos.ELEMENT_COUNT_URN, - labels, - num_source_elems) - self.assert_has_distribution( - counters, - monitoring_infos.SAMPLED_BYTE_SIZE_URN, - labels, - min=hamcrest.greater_than(0), - max=hamcrest.greater_than(0), - sum=hamcrest.greater_than(0), - count=hamcrest.greater_than(0)) - except: - print(res._monitoring_infos_by_stage) - raise - - def test_non_user_metrics(self): - p = self.create_pipeline() - - pcoll = p | beam.Create(['a', 'zzz']) - # pylint: disable=expression-not-assigned - pcoll | 'MyStep' >> beam.FlatMap(lambda x: None) - res = p.run() - res.wait_until_finish() - - result_metrics = res.monitoring_metrics() - all_metrics_via_montoring_infos = result_metrics.query() - - def assert_counter_exists(metrics, namespace, name, step): - found = 0 - metric_key = MetricKey(step, MetricName(namespace, name)) - for m in metrics['counters']: - if m.key == metric_key: - found = found + 1 - self.assertEqual( - 1, found, "Did not find exactly 1 metric for %s." % metric_key) - - urns = [ - monitoring_infos.START_BUNDLE_MSECS_URN, - monitoring_infos.PROCESS_BUNDLE_MSECS_URN, - monitoring_infos.FINISH_BUNDLE_MSECS_URN, - monitoring_infos.TOTAL_MSECS_URN, - ] - for urn in urns: - split = urn.split(':') - namespace = split[0] - name = ':'.join(split[1:]) - assert_counter_exists( - all_metrics_via_montoring_infos, - namespace, - name, - step='Create/Impulse') - assert_counter_exists( - all_metrics_via_montoring_infos, namespace, name, step='MyStep') - - # Due to somewhat non-deterministic nature of state sampling and sleep, - # this test is flaky when state duration is low. - # Since increasing state duration significantly would also slow down - # the test suite, we are retrying twice on failure as a mitigation. - def test_progress_metrics(self): - p = self.create_pipeline() - - _ = ( - p - | beam.Create([0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS], - reshuffle=False) - | beam.Map(time.sleep) - | beam.Map(lambda x: ('key', x)) - | beam.GroupByKey() - | 'm_out' >> beam.FlatMap( - lambda x: [ - 1, - 2, - 3, - 4, - 5, - beam.pvalue.TaggedOutput('once', x), - beam.pvalue.TaggedOutput('twice', x), - beam.pvalue.TaggedOutput('twice', x) - ])) - - res = p.run() - res.wait_until_finish() - - def has_mi_for_ptransform(mon_infos, ptransform): - for mi in mon_infos: - if ptransform in mi.labels[monitoring_infos.PTRANSFORM_LABEL]: - return True - return False - - try: - # Test the new MonitoringInfo monitoring format. - self.assertEqual(3, len(res._monitoring_infos_by_stage)) - pregbk_mis, postgbk_mis = [ - mi for stage, mi in res._monitoring_infos_by_stage.items() if stage] - - if not has_mi_for_ptransform(pregbk_mis, 'Create/Map(decode)'): - # The monitoring infos above are actually unordered. Swap. - pregbk_mis, postgbk_mis = postgbk_mis, pregbk_mis - - # pregbk monitoring infos - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_3' - } - self.assert_has_counter( - pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4) - self.assert_has_distribution( - pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels) - - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_4' - } - self.assert_has_counter( - pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4) - self.assert_has_distribution( - pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels) - - labels = {monitoring_infos.PTRANSFORM_LABEL: 'Map(sleep)'} - self.assert_has_counter( - pregbk_mis, - monitoring_infos.TOTAL_MSECS_URN, - labels, - ge_value=4 * DEFAULT_SAMPLING_PERIOD_MS) - - # postgbk monitoring infos - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_6' - } - self.assert_has_counter( - postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=1) - self.assert_has_distribution( - postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels) - - labels = { - monitoring_infos.PCOLLECTION_LABEL: 'ref_PCollection_PCollection_7' - } - self.assert_has_counter( - postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=5) - self.assert_has_distribution( - postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels) - except: - print(res._monitoring_infos_by_stage) - raise - - -@unittest.skip('Runner-initiated splitting not yet supported') -class RayRunnerSplitTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - if not ray.is_initialized(): - ray.init(local_mode=True) - - def create_pipeline(self, is_drain=False): - return beam.Pipeline(runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner()) - - def test_checkpoint(self): - # This split manager will get re-invoked on each smaller split, - # so N times for N elements. - element_counter = ElementCounter() - - def split_manager(num_elements): - # Send at least one element so it can make forward progress. - element_counter.reset() - breakpoint = element_counter.set_breakpoint(1) - # Cede control back to the runner so data can be sent. - yield - breakpoint.wait() - # Split as close to current as possible. - split_result = yield 0.0 - # Verify we split at exactly the first element. - self.verify_channel_split(split_result, 0, 1) - # Continue processing. - breakpoint.clear() - - self.run_split_pipeline(split_manager, list('abc'), element_counter) - - def test_split_half(self): - total_num_elements = 25 - seen_bundle_sizes = [] - element_counter = ElementCounter() - - def split_manager(num_elements): - seen_bundle_sizes.append(num_elements) - if num_elements == total_num_elements: - element_counter.reset() - breakpoint = element_counter.set_breakpoint(5) - yield - breakpoint.wait() - # Split the remainder (20, then 10, elements) in half. - split1 = yield 0.5 - self.verify_channel_split(split1, 14, 15) # remainder is 15 to end - split2 = yield 0.5 - self.verify_channel_split(split2, 9, 10) # remainder is 10 to end - breakpoint.clear() - - self.run_split_pipeline( - split_manager, range(total_num_elements), element_counter) - self.assertEqual([25, 15], seen_bundle_sizes) - - def run_split_pipeline(self, split_manager, elements, element_counter=None): - with fn_runner.split_manager('Identity', split_manager): - with self.create_pipeline() as p: - res = ( + "Found (%s) Expected only 1 monitoring_info for %s." + % ( + found, + (urn, labels, str(description)), + ), + ) + + def create_pipeline(self): + return beam.Pipeline( + runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner() + ) + + def test_element_count_metrics(self): + class GenerateTwoOutputs(beam.DoFn): + def process(self, element): + yield str(element) + "1" + yield beam.pvalue.TaggedOutput("SecondOutput", str(element) + "2") + yield beam.pvalue.TaggedOutput("SecondOutput", str(element) + "2") + yield beam.pvalue.TaggedOutput("ThirdOutput", str(element) + "3") + + class PassThrough(beam.DoFn): + def process(self, element): + yield element + + p = self.create_pipeline() + + # Produce enough elements to make sure byte sampling occurs. + num_source_elems = 100 + pcoll = p | beam.Create( + ["a%d" % i for i in range(num_source_elems)], reshuffle=False + ) + + # pylint: disable=expression-not-assigned + pardo = "StepThatDoesTwoOutputs" >> beam.ParDo( + GenerateTwoOutputs() + ).with_outputs("SecondOutput", "ThirdOutput", main="FirstAndMainOutput") + + # Actually feed pcollection to pardo + second_output, third_output, first_output = pcoll | pardo + + # consume some of elements + merged = (first_output, second_output, third_output) | beam.Flatten() + merged | ("PassThrough") >> beam.ParDo(PassThrough()) + second_output | ("PassThrough2") >> beam.ParDo(PassThrough()) + + res = p.run() + res.wait_until_finish() + + result_metrics = res.monitoring_metrics() + + counters = result_metrics.monitoring_infos() + # All element count and byte count metrics must have a PCOLLECTION_LABEL. + self.assertFalse( + [ + x + for x in counters + if x.urn + in [ + monitoring_infos.ELEMENT_COUNT_URN, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + ] + and monitoring_infos.PCOLLECTION_LABEL not in x.labels + ] + ) + try: + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_1" + } + self.assert_has_counter( + counters, monitoring_infos.ELEMENT_COUNT_URN, labels, 1 + ) + + # Create output. + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_3" + } + self.assert_has_counter( + counters, monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # GenerateTwoOutputs, main output. + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_4" + } + self.assert_has_counter( + counters, monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # GenerateTwoOutputs, "SecondOutput" output. + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_5" + } + self.assert_has_counter( + counters, + monitoring_infos.ELEMENT_COUNT_URN, + labels, + 2 * num_source_elems, + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # GenerateTwoOutputs, "ThirdOutput" output. + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_6" + } + self.assert_has_counter( + counters, monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # Skipping other pcollections due to non-deterministic naming for multiple + # outputs. + # Flatten/Read, main output. + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_7" + } + self.assert_has_counter( + counters, + monitoring_infos.ELEMENT_COUNT_URN, + labels, + 4 * num_source_elems, + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # PassThrough, main output + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_8" + } + self.assert_has_counter( + counters, + monitoring_infos.ELEMENT_COUNT_URN, + labels, + 4 * num_source_elems, + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + + # PassThrough2, main output + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_9" + } + self.assert_has_counter( + counters, monitoring_infos.ELEMENT_COUNT_URN, labels, num_source_elems + ) + self.assert_has_distribution( + counters, + monitoring_infos.SAMPLED_BYTE_SIZE_URN, + labels, + min=hamcrest.greater_than(0), + max=hamcrest.greater_than(0), + sum=hamcrest.greater_than(0), + count=hamcrest.greater_than(0), + ) + except Exception: + raise + + def test_non_user_metrics(self): + p = self.create_pipeline() + + pcoll = p | beam.Create(["a", "zzz"]) + # pylint: disable=expression-not-assigned + pcoll | "MyStep" >> beam.FlatMap(lambda x: None) + res = p.run() + res.wait_until_finish() + + result_metrics = res.monitoring_metrics() + all_metrics_via_montoring_infos = result_metrics.query() + + def assert_counter_exists(metrics, namespace, name, step): + found = 0 + metric_key = MetricKey(step, MetricName(namespace, name)) + for m in metrics["counters"]: + if m.key == metric_key: + found = found + 1 + self.assertEqual( + 1, found, "Did not find exactly 1 metric for %s." % metric_key + ) + + urns = [ + monitoring_infos.START_BUNDLE_MSECS_URN, + monitoring_infos.PROCESS_BUNDLE_MSECS_URN, + monitoring_infos.FINISH_BUNDLE_MSECS_URN, + monitoring_infos.TOTAL_MSECS_URN, + ] + for urn in urns: + split = urn.split(":") + namespace = split[0] + name = ":".join(split[1:]) + assert_counter_exists( + all_metrics_via_montoring_infos, namespace, name, step="Create/Impulse" + ) + assert_counter_exists( + all_metrics_via_montoring_infos, namespace, name, step="MyStep" + ) + + # Due to somewhat non-deterministic nature of state sampling and sleep, + # this test is flaky when state duration is low. + # Since increasing state duration significantly would also slow down + # the test suite, we are retrying twice on failure as a mitigation. + def test_progress_metrics(self): + p = self.create_pipeline() + + _ = ( p - | beam.Create(elements) - | beam.Reshuffle() - | 'Identity' >> beam.Map(lambda x: x) - | beam.Map(lambda x: element_counter.increment() or x)) - assert_that(res, equal_to(elements)) - - def run_sdf_checkpoint(self, is_drain=False): - element_counter = ElementCounter() - - def split_manager(num_elements): - if num_elements > 0: - element_counter.reset() - breakpoint = element_counter.set_breakpoint(1) - yield - breakpoint.wait() - yield 0 - breakpoint.clear() - - # Everything should be perfectly split. - - elements = [2, 3] - expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]] - self.run_sdf_split_pipeline( - split_manager, - elements, - element_counter, - expected_groups, - is_drain=is_drain) - - def run_sdf_split_half(self, is_drain=False): - element_counter = ElementCounter() - is_first_bundle = True - - def split_manager(num_elements): - nonlocal is_first_bundle - if is_first_bundle and num_elements > 0: - is_first_bundle = False - breakpoint = element_counter.set_breakpoint(1) - yield - breakpoint.wait() - split1 = yield 0.5 - split2 = yield 0.5 - split3 = yield 0.5 - self.verify_channel_split(split1, 0, 1) - self.verify_channel_split(split2, -1, 1) - self.verify_channel_split(split3, -1, 1) - breakpoint.clear() - - elements = [4, 4] - expected_groups = [[(4, 0)], [(4, 1)], [(4, 2), (4, 3)], [(4, 0), (4, 1), - (4, 2), (4, 3)]] - - self.run_sdf_split_pipeline( + | beam.Create([0, 0, 0, 5e-3 * DEFAULT_SAMPLING_PERIOD_MS], reshuffle=False) + | beam.Map(time.sleep) + | beam.Map(lambda x: ("key", x)) + | beam.GroupByKey() + | "m_out" + >> beam.FlatMap( + lambda x: [ + 1, + 2, + 3, + 4, + 5, + beam.pvalue.TaggedOutput("once", x), + beam.pvalue.TaggedOutput("twice", x), + beam.pvalue.TaggedOutput("twice", x), + ] + ) + ) + + res = p.run() + res.wait_until_finish() + + def has_mi_for_ptransform(mon_infos, ptransform): + for mi in mon_infos: + if ptransform in mi.labels[monitoring_infos.PTRANSFORM_LABEL]: + return True + return False + + try: + # Test the new MonitoringInfo monitoring format. + self.assertEqual(3, len(res._monitoring_infos_by_stage)) + pregbk_mis, postgbk_mis = [ + mi for stage, mi in res._monitoring_infos_by_stage.items() if stage + ] + + if not has_mi_for_ptransform(pregbk_mis, "Create/Map(decode)"): + # The monitoring infos above are actually unordered. Swap. + pregbk_mis, postgbk_mis = postgbk_mis, pregbk_mis + + # pregbk monitoring infos + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_3" + } + self.assert_has_counter( + pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4 + ) + self.assert_has_distribution( + pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels + ) + + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_4" + } + self.assert_has_counter( + pregbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=4 + ) + self.assert_has_distribution( + pregbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels + ) + + labels = {monitoring_infos.PTRANSFORM_LABEL: "Map(sleep)"} + self.assert_has_counter( + pregbk_mis, + monitoring_infos.TOTAL_MSECS_URN, + labels, + ge_value=4 * DEFAULT_SAMPLING_PERIOD_MS, + ) + + # postgbk monitoring infos + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_6" + } + self.assert_has_counter( + postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=1 + ) + self.assert_has_distribution( + postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels + ) + + labels = { + monitoring_infos.PCOLLECTION_LABEL: "ref_PCollection_PCollection_7" + } + self.assert_has_counter( + postgbk_mis, monitoring_infos.ELEMENT_COUNT_URN, labels, value=5 + ) + self.assert_has_distribution( + postgbk_mis, monitoring_infos.SAMPLED_BYTE_SIZE_URN, labels + ) + except Exception: + raise + + +@unittest.skip("Runner-initiated splitting not yet supported") +class RayRunnerSplitTest(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + if not ray.is_initialized(): + ray.init(local_mode=True) + + def create_pipeline(self, is_drain=False): + return beam.Pipeline( + runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner() + ) + + def test_checkpoint(self): + # This split manager will get re-invoked on each smaller split, + # so N times for N elements. + element_counter = ElementCounter() + + def split_manager(num_elements): + # Send at least one element so it can make forward progress. + element_counter.reset() + breakpoint = element_counter.set_breakpoint(1) + # Cede control back to the runner so data can be sent. + yield + breakpoint.wait() + # Split as close to current as possible. + split_result = yield 0.0 + # Verify we split at exactly the first element. + self.verify_channel_split(split_result, 0, 1) + # Continue processing. + breakpoint.clear() + + self.run_split_pipeline(split_manager, list("abc"), element_counter) + + def test_split_half(self): + total_num_elements = 25 + seen_bundle_sizes = [] + element_counter = ElementCounter() + + def split_manager(num_elements): + seen_bundle_sizes.append(num_elements) + if num_elements == total_num_elements: + element_counter.reset() + breakpoint = element_counter.set_breakpoint(5) + yield + breakpoint.wait() + # Split the remainder (20, then 10, elements) in half. + split1 = yield 0.5 + self.verify_channel_split(split1, 14, 15) # remainder is 15 to end + split2 = yield 0.5 + self.verify_channel_split(split2, 9, 10) # remainder is 10 to end + breakpoint.clear() + + self.run_split_pipeline( + split_manager, range(total_num_elements), element_counter + ) + self.assertEqual([25, 15], seen_bundle_sizes) + + def run_split_pipeline(self, split_manager, elements, element_counter=None): + with fn_runner.split_manager("Identity", split_manager): + with self.create_pipeline() as p: + res = ( + p + | beam.Create(elements) + | beam.Reshuffle() + | "Identity" >> beam.Map(lambda x: x) + | beam.Map(lambda x: element_counter.increment() or x) + ) + assert_that(res, equal_to(elements)) + + def run_sdf_checkpoint(self, is_drain=False): + element_counter = ElementCounter() + + def split_manager(num_elements): + if num_elements > 0: + element_counter.reset() + breakpoint = element_counter.set_breakpoint(1) + yield + breakpoint.wait() + yield 0 + breakpoint.clear() + + # Everything should be perfectly split. + + elements = [2, 3] + expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]] + self.run_sdf_split_pipeline( + split_manager, elements, element_counter, expected_groups, is_drain=is_drain + ) + + def run_sdf_split_half(self, is_drain=False): + element_counter = ElementCounter() + is_first_bundle = True + + def split_manager(num_elements): + nonlocal is_first_bundle + if is_first_bundle and num_elements > 0: + is_first_bundle = False + breakpoint = element_counter.set_breakpoint(1) + yield + breakpoint.wait() + split1 = yield 0.5 + split2 = yield 0.5 + split3 = yield 0.5 + self.verify_channel_split(split1, 0, 1) + self.verify_channel_split(split2, -1, 1) + self.verify_channel_split(split3, -1, 1) + breakpoint.clear() + + elements = [4, 4] + expected_groups = [ + [(4, 0)], + [(4, 1)], + [(4, 2), (4, 3)], + [(4, 0), (4, 1), (4, 2), (4, 3)], + ] + + self.run_sdf_split_pipeline( + split_manager, elements, element_counter, expected_groups, is_drain=is_drain + ) + + def run_split_crazy_sdf(self, seed=None, is_drain=False): + if seed is None: + seed = random.randrange(1 << 20) + r = random.Random(seed) + element_counter = ElementCounter() + + def split_manager(num_elements): + if num_elements > 0: + element_counter.reset() + wait_for = r.randrange(num_elements) + breakpoint = element_counter.set_breakpoint(wait_for) + yield + breakpoint.wait() + yield r.random() + yield r.random() + breakpoint.clear() + + try: + elements = [r.randrange(5, 10) for _ in range(5)] + self.run_sdf_split_pipeline( + split_manager, elements, element_counter, is_drain=is_drain + ) + except Exception: + _LOGGER.error("test_split_crazy_sdf.seed = %s", seed) + raise + + @unittest.skip("SDF not yet supported") + def test_nosplit_sdf(self): + def split_manager(num_elements): + yield + + elements = [1, 2, 3] + expected_groups = [[(e, k) for k in range(e)] for e in elements] + self.run_sdf_split_pipeline( + split_manager, elements, ElementCounter(), expected_groups + ) + + @unittest.skip("SDF not yet supported") + def test_checkpoint_sdf(self): + self.run_sdf_checkpoint(is_drain=False) + + @unittest.skip("SDF not yet supported") + def test_checkpoint_draining_sdf(self): + self.run_sdf_checkpoint(is_drain=True) + + @unittest.skip("SDF not yet supported") + def test_split_half_sdf(self): + self.run_sdf_split_half(is_drain=False) + + @unittest.skip("SDF not yet supported") + def test_split_half_draining_sdf(self): + self.run_sdf_split_half(is_drain=True) + + @unittest.skip("SDF not yet supported") + def test_split_crazy_sdf(self, seed=None): + self.run_split_crazy_sdf(seed=seed, is_drain=False) + + @unittest.skip("SDF not yet supported") + def test_split_crazy_draining_sdf(self, seed=None): + self.run_split_crazy_sdf(seed=seed, is_drain=True) + + def run_sdf_split_pipeline( + self, split_manager, elements, element_counter, - expected_groups, - is_drain=is_drain) - - def run_split_crazy_sdf(self, seed=None, is_drain=False): - if seed is None: - seed = random.randrange(1 << 20) - r = random.Random(seed) - element_counter = ElementCounter() - - def split_manager(num_elements): - if num_elements > 0: - element_counter.reset() - wait_for = r.randrange(num_elements) - breakpoint = element_counter.set_breakpoint(wait_for) - yield - breakpoint.wait() - yield r.random() - yield r.random() - breakpoint.clear() - - try: - elements = [r.randrange(5, 10) for _ in range(5)] - self.run_sdf_split_pipeline( - split_manager, elements, element_counter, is_drain=is_drain) - except Exception: - _LOGGER.error('test_split_crazy_sdf.seed = %s', seed) - raise - - @unittest.skip('SDF not yet supported') - def test_nosplit_sdf(self): - def split_manager(num_elements): - yield - - elements = [1, 2, 3] - expected_groups = [[(e, k) for k in range(e)] for e in elements] - self.run_sdf_split_pipeline( - split_manager, elements, ElementCounter(), expected_groups) - - @unittest.skip('SDF not yet supported') - def test_checkpoint_sdf(self): - self.run_sdf_checkpoint(is_drain=False) - - @unittest.skip('SDF not yet supported') - def test_checkpoint_draining_sdf(self): - self.run_sdf_checkpoint(is_drain=True) - - @unittest.skip('SDF not yet supported') - def test_split_half_sdf(self): - self.run_sdf_split_half(is_drain=False) - - @unittest.skip('SDF not yet supported') - def test_split_half_draining_sdf(self): - self.run_sdf_split_half(is_drain=True) - - @unittest.skip('SDF not yet supported') - def test_split_crazy_sdf(self, seed=None): - self.run_split_crazy_sdf(seed=seed, is_drain=False) - - @unittest.skip('SDF not yet supported') - def test_split_crazy_draining_sdf(self, seed=None): - self.run_split_crazy_sdf(seed=seed, is_drain=True) - - def run_sdf_split_pipeline( - self, - split_manager, - elements, - element_counter, - expected_groups=None, - is_drain=False): - # Define an SDF that for each input x produces [(x, k) for k in range(x)]. - - class EnumerateProvider(beam.transforms.core.RestrictionProvider): - def initial_restriction(self, element): - return restriction_trackers.OffsetRange(0, element) + expected_groups=None, + is_drain=False, + ): + # Define an SDF that for each input x produces [(x, k) for k in range(x)]. + + class EnumerateProvider(beam.transforms.core.RestrictionProvider): + def initial_restriction(self, element): + return restriction_trackers.OffsetRange(0, element) + + def create_tracker(self, restriction): + return restriction_trackers.OffsetRestrictionTracker(restriction) + + def split(self, element, restriction): + # Don't do any initial splitting to simplify test. + return [restriction] + + def restriction_size(self, element, restriction): + return restriction.size() + + def is_bounded(self): + return True + + class EnumerateSdf(beam.DoFn): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider()), + ): + to_emit = [] + cur = restriction_tracker.current_restriction().start + while restriction_tracker.try_claim(cur): + to_emit.append((element, cur)) + element_counter.increment() + cur += 1 + # Emitting in batches for tighter testing. + yield to_emit + + expected = [(e, k) for e in elements for k in range(e)] + + with fn_runner.split_manager("SDF", split_manager): + with self.create_pipeline(is_drain=is_drain) as p: + grouped = ( + p + | beam.Create(elements, reshuffle=False) + | "SDF" >> beam.ParDo(EnumerateSdf()) + ) + flat = grouped | beam.FlatMap(lambda x: x) + assert_that(flat, equal_to(expected)) + if expected_groups: + assert_that( + grouped, equal_to(expected_groups), label="CheckGrouped" + ) + + def verify_channel_split(self, split_result, last_primary, first_residual): + self.assertEqual(1, len(split_result.channel_splits), split_result) + (channel_split,) = split_result.channel_splits + self.assertEqual(last_primary, channel_split.last_primary_element) + self.assertEqual(first_residual, channel_split.first_residual_element) + # There should be a primary and residual application for each element + # not covered above. + self.assertEqual( + first_residual - last_primary - 1, + len(split_result.primary_roots), + split_result.primary_roots, + ) + self.assertEqual( + first_residual - last_primary - 1, + len(split_result.residual_roots), + split_result.residual_roots, + ) - def create_tracker(self, restriction): - return restriction_trackers.OffsetRestrictionTracker(restriction) - def split(self, element, restriction): - # Don't do any initial splitting to simplify test. - return [restriction] +class ElementCounter(object): + """Used to wait until a certain number of elements are seen.""" - def restriction_size(self, element, restriction): - return restriction.size() + def __init__(self): + self._cv = threading.Condition() + self.reset() - def is_bounded(self): - return True - - class EnumerateSdf(beam.DoFn): - def process( - self, - element, - restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())): - to_emit = [] - cur = restriction_tracker.current_restriction().start - while restriction_tracker.try_claim(cur): - to_emit.append((element, cur)) - element_counter.increment() - cur += 1 - # Emitting in batches for tighter testing. - yield to_emit - - expected = [(e, k) for e in elements for k in range(e)] - - with fn_runner.split_manager('SDF', split_manager): - with self.create_pipeline(is_drain=is_drain) as p: - grouped = ( - p - | beam.Create(elements, reshuffle=False) - | 'SDF' >> beam.ParDo(EnumerateSdf())) - flat = grouped | beam.FlatMap(lambda x: x) - assert_that(flat, equal_to(expected)) - if expected_groups: - assert_that(grouped, equal_to(expected_groups), label='CheckGrouped') - - def verify_channel_split(self, split_result, last_primary, first_residual): - self.assertEqual(1, len(split_result.channel_splits), split_result) - channel_split, = split_result.channel_splits - self.assertEqual(last_primary, channel_split.last_primary_element) - self.assertEqual(first_residual, channel_split.first_residual_element) - # There should be a primary and residual application for each element - # not covered above. - self.assertEqual( - first_residual - last_primary - 1, - len(split_result.primary_roots), - split_result.primary_roots) - self.assertEqual( - first_residual - last_primary - 1, - len(split_result.residual_roots), - split_result.residual_roots) + def reset(self): + with self._cv: + self._breakpoints = collections.defaultdict(list) + self._count = 0 + def increment(self): + with self._cv: + self._count += 1 + self._cv.notify_all() + breakpoints = list(self._breakpoints[self._count]) + for breakpoint in breakpoints: + breakpoint.wait() -class ElementCounter(object): - """Used to wait until a certain number of elements are seen.""" - def __init__(self): - self._cv = threading.Condition() - self.reset() - - def reset(self): - with self._cv: - self._breakpoints = collections.defaultdict(list) - self._count = 0 - - def increment(self): - with self._cv: - self._count += 1 - self._cv.notify_all() - breakpoints = list(self._breakpoints[self._count]) - for breakpoint in breakpoints: - breakpoint.wait() - - def set_breakpoint(self, value): - with self._cv: - event = threading.Event() - self._breakpoints[value].append(event) - - class Breakpoint(object): - @staticmethod - def wait(timeout=10): + def set_breakpoint(self, value): with self._cv: - start = time.time() - while self._count < value: - elapsed = time.time() - start - if elapsed > timeout: - raise RuntimeError('Timed out waiting for %s' % value) - self._cv.wait(timeout - elapsed) + event = threading.Event() + self._breakpoints[value].append(event) - @staticmethod - def clear(): - event.set() + class Breakpoint(object): + @staticmethod + def wait(timeout=10): + with self._cv: + start = time.time() + while self._count < value: + elapsed = time.time() - start + if elapsed > timeout: + raise RuntimeError("Timed out waiting for %s" % value) + self._cv.wait(timeout - elapsed) - return Breakpoint() + @staticmethod + def clear(): + event.set() - def __reduce__(self): - # Ensure we get the same element back through a pickling round-trip. - name = uuid.uuid4().hex - _pickled_element_counters[name] = self - return _unpickle_element_counter, (name, ) + return Breakpoint() + + def __reduce__(self): + # Ensure we get the same element back through a pickling round-trip. + name = uuid.uuid4().hex + _pickled_element_counters[name] = self + return _unpickle_element_counter, (name,) _pickled_element_counters = {} # type: Dict[str, ElementCounter] def _unpickle_element_counter(name): - return _pickled_element_counters[name] + return _pickled_element_counters[name] class EventRecorder(object): - """Used to be registered as a callback in bundle finalization. - - The reason why records are written into a tmp file is, the in-memory dataset - cannot keep callback records when passing into one DoFn. - """ - def __init__(self, tmp_dir): - self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex) - os.mkdir(self.tmp_dir) - - def record(self, content): - file_path = os.path.join(self.tmp_dir, uuid.uuid4().hex + '.txt') - with open(file_path, 'w') as f: - f.write(content) - - def events(self): - content = [] - record_files = [ - f for f in os.listdir(self.tmp_dir) - if os.path.isfile(os.path.join(self.tmp_dir, f)) - ] - for file in record_files: - with open(os.path.join(self.tmp_dir, file), 'r') as f: - content.append(f.read()) - return sorted(content) - - def cleanup(self): - shutil.rmtree(self.tmp_dir) + """Used to be registered as a callback in bundle finalization. + + The reason why records are written into a tmp file is, the in-memory dataset + cannot keep callback records when passing into one DoFn. + """ + + def __init__(self, tmp_dir): + self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex) + os.mkdir(self.tmp_dir) + + def record(self, content): + file_path = os.path.join(self.tmp_dir, uuid.uuid4().hex + ".txt") + with open(file_path, "w") as f: + f.write(content) + + def events(self): + content = [] + record_files = [ + f + for f in os.listdir(self.tmp_dir) + if os.path.isfile(os.path.join(self.tmp_dir, f)) + ] + for file in record_files: + with open(os.path.join(self.tmp_dir, file), "r") as f: + content.append(f.read()) + return sorted(content) + + def cleanup(self): + shutil.rmtree(self.tmp_dir) class ExpandStringsProvider(beam.transforms.core.RestrictionProvider): - """A RestrictionProvider that used for sdf related tests.""" - def initial_restriction(self, element): - return restriction_trackers.OffsetRange(0, len(element)) + """A RestrictionProvider that used for sdf related tests.""" - def create_tracker(self, restriction): - return restriction_trackers.OffsetRestrictionTracker(restriction) + def initial_restriction(self, element): + return restriction_trackers.OffsetRange(0, len(element)) - def split(self, element, restriction): - desired_bundle_size = restriction.size() // 2 - return restriction.split(desired_bundle_size) + def create_tracker(self, restriction): + return restriction_trackers.OffsetRestrictionTracker(restriction) + + def split(self, element, restriction): + desired_bundle_size = restriction.size() // 2 + return restriction.split(desired_bundle_size) - def restriction_size(self, element, restriction): - return restriction.size() + def restriction_size(self, element, restriction): + return restriction.size() -class UnboundedOffsetRestrictionTracker( - restriction_trackers.OffsetRestrictionTracker): - def is_bounded(self): - return False +class UnboundedOffsetRestrictionTracker(restriction_trackers.OffsetRestrictionTracker): + def is_bounded(self): + return False class OffsetRangeProvider(beam.transforms.core.RestrictionProvider): - def __init__(self, use_bounded_offset_range, checkpoint_only=False): - self.use_bounded_offset_range = use_bounded_offset_range - self.checkpoint_only = checkpoint_only + def __init__(self, use_bounded_offset_range, checkpoint_only=False): + self.use_bounded_offset_range = use_bounded_offset_range + self.checkpoint_only = checkpoint_only - def initial_restriction(self, element): - return restriction_trackers.OffsetRange(0, element) + def initial_restriction(self, element): + return restriction_trackers.OffsetRange(0, element) - def create_tracker(self, restriction): - if self.checkpoint_only: + def create_tracker(self, restriction): + if self.checkpoint_only: - class CheckpointOnlyOffsetRestrictionTracker( - restriction_trackers.OffsetRestrictionTracker): - def try_split(self, unused_fraction_of_remainder): - return super().try_split(0.0) + class CheckpointOnlyOffsetRestrictionTracker( + restriction_trackers.OffsetRestrictionTracker + ): + def try_split(self, unused_fraction_of_remainder): + return super().try_split(0.0) - return CheckpointOnlyOffsetRestrictionTracker(restriction) - if self.use_bounded_offset_range: - return restriction_trackers.OffsetRestrictionTracker(restriction) - return UnboundedOffsetRestrictionTracker(restriction) + return CheckpointOnlyOffsetRestrictionTracker(restriction) + if self.use_bounded_offset_range: + return restriction_trackers.OffsetRestrictionTracker(restriction) + return UnboundedOffsetRestrictionTracker(restriction) - def split(self, element, restriction): - return [restriction] + def split(self, element, restriction): + return [restriction] - def restriction_size(self, element, restriction): - return restriction.size() + def restriction_size(self, element, restriction): + return restriction.size() class OffsetRangeProviderWithTruncate(OffsetRangeProvider): - def __init__(self): - super().__init__(True) + def __init__(self): + super().__init__(True) - def truncate(self, element, restriction): - return restriction_trackers.OffsetRange( - restriction.start, restriction.stop // 2) + def truncate(self, element, restriction): + return restriction_trackers.OffsetRange( + restriction.start, restriction.stop // 2 + ) # TODO(robertwb): Why does pickling break when this is inlined? class CustomMergingWindowFn(window.WindowFn): - def assign(self, assign_context): - return [ - window.IntervalWindow( - assign_context.timestamp, assign_context.timestamp + 1000) - ] - - def merge(self, merge_context): - evens = [w for w in merge_context.windows if w.start % 2 == 0] - if evens: - merge_context.merge( - evens, - window.IntervalWindow( - min(w.start for w in evens), max(w.end for w in evens))) - - def get_window_coder(self): - return coders.IntervalWindowCoder() + def assign(self, assign_context): + return [ + window.IntervalWindow( + assign_context.timestamp, assign_context.timestamp + 1000 + ) + ] + + def merge(self, merge_context): + evens = [w for w in merge_context.windows if w.start % 2 == 0] + if evens: + merge_context.merge( + evens, + window.IntervalWindow( + min(w.start for w in evens), max(w.end for w in evens) + ), + ) + + def get_window_coder(self): + return coders.IntervalWindowCoder() class ExpectingSideInputsFn(beam.DoFn): - def __init__(self, name): - self._name = name + def __init__(self, name): + self._name = name + + def default_label(self): + return self._name - def default_label(self): - return self._name + def process(self, element, *side_inputs): + logging.info("Running %s (side inputs: %s)", self._name, side_inputs) + if not all(list(s) for s in side_inputs): + raise ValueError(f"Missing data in side input {side_inputs}") + yield self._name - def process(self, element, *side_inputs): - logging.info('Running %s (side inputs: %s)', self._name, side_inputs) - if not all(list(s) for s in side_inputs): - raise ValueError(f'Missing data in side input {side_inputs}') - yield self._name logging.getLogger().setLevel(logging.INFO) -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - unittest.main() +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/ray_beam_runner/portability/state.py b/ray_beam_runner/portability/state.py index 07f9c2d..0f8722f 100644 --- a/ray_beam_runner/portability/state.py +++ b/ray_beam_runner/portability/state.py @@ -28,75 +28,79 @@ @ray.remote class _ActorStateManager: - def __init__(self): - self._data = collections.defaultdict(lambda : []) - - def get_raw( - self, - bundle_id: str, - state_key: str, - continuation_token: Optional[bytes] = None, - ) -> Tuple[bytes, Optional[bytes]]: - if continuation_token: - continuation_token = int(continuation_token) - else: - continuation_token = 0 - - new_cont_token = continuation_token + 1 - if len(self._data[(bundle_id, state_key)]) == new_cont_token: - return self._data[(bundle_id, state_key)][continuation_token], None - else: - return (self._data[(bundle_id, state_key)][continuation_token], - str(continuation_token + 1).encode('utf8')) - - def append_raw( - self, - bundle_id: str, - state_key: str, - data: bytes - ): - self._data[(bundle_id, state_key)].append(data) - - def clear(self, bundle_id: str, state_key: str): - self._data[(bundle_id, state_key)] = [] + def __init__(self): + self._data = collections.defaultdict(lambda: []) + + def get_raw( + self, + bundle_id: str, + state_key: str, + continuation_token: Optional[bytes] = None, + ) -> Tuple[bytes, Optional[bytes]]: + if continuation_token: + continuation_token = int(continuation_token) + else: + continuation_token = 0 + + new_cont_token = continuation_token + 1 + if len(self._data[(bundle_id, state_key)]) == new_cont_token: + return self._data[(bundle_id, state_key)][continuation_token], None + else: + return ( + self._data[(bundle_id, state_key)][continuation_token], + str(continuation_token + 1).encode("utf8"), + ) + + def append_raw(self, bundle_id: str, state_key: str, data: bytes): + self._data[(bundle_id, state_key)].append(data) + + def clear(self, bundle_id: str, state_key: str): + self._data[(bundle_id, state_key)] = [] class RayStateManager(sdk_worker.StateHandler): - def __init__(self, state_actor: Optional[_ActorStateManager] = None): - self._state_actor = state_actor or _ActorStateManager.remote() - self._instruction_id: Optional[str] = None - - @staticmethod - def _to_key(state_key: beam_fn_api_pb2.StateKey): - return state_key.SerializeToString() - - def get_raw( - self, - state_key, # type: beam_fn_api_pb2.StateKey - continuation_token=None # type: Optional[bytes] - ) -> Tuple[bytes, Optional[bytes]]: - assert self._instruction_id is not None - return ray.get( - self._state_actor.get_raw.remote(self._instruction_id, RayStateManager._to_key(state_key), continuation_token)) - - def append_raw( - self, - state_key: beam_fn_api_pb2.StateKey, - data: bytes - ) -> sdk_worker._Future: - assert self._instruction_id is not None - return self._state_actor.append_raw.remote(self._instruction_id, RayStateManager._to_key(state_key), data) - - def clear(self, state_key: beam_fn_api_pb2.StateKey) -> sdk_worker._Future: - # TODO(pabloem): Does the ray future work as a replacement of Beam _Future? - assert self._instruction_id is not None - return self._state_actor.clear.remote(self._instruction_id, RayStateManager._to_key(state_key)) - - @contextlib.contextmanager - def process_instruction_id(self, bundle_id: str) -> Iterator[None]: - self._instruction_id = bundle_id - yield - self._instruction_id = None - - def done(self): - pass \ No newline at end of file + def __init__(self, state_actor: Optional[_ActorStateManager] = None): + self._state_actor = state_actor or _ActorStateManager.remote() + self._instruction_id: Optional[str] = None + + @staticmethod + def _to_key(state_key: beam_fn_api_pb2.StateKey): + return state_key.SerializeToString() + + def get_raw( + self, + state_key, # type: beam_fn_api_pb2.StateKey + continuation_token=None, # type: Optional[bytes] + ) -> Tuple[bytes, Optional[bytes]]: + assert self._instruction_id is not None + return ray.get( + self._state_actor.get_raw.remote( + self._instruction_id, + RayStateManager._to_key(state_key), + continuation_token, + ) + ) + + def append_raw( + self, state_key: beam_fn_api_pb2.StateKey, data: bytes + ) -> sdk_worker._Future: + assert self._instruction_id is not None + return self._state_actor.append_raw.remote( + self._instruction_id, RayStateManager._to_key(state_key), data + ) + + def clear(self, state_key: beam_fn_api_pb2.StateKey) -> sdk_worker._Future: + # TODO(pabloem): Does the ray future work as a replacement of Beam _Future? + assert self._instruction_id is not None + return self._state_actor.clear.remote( + self._instruction_id, RayStateManager._to_key(state_key) + ) + + @contextlib.contextmanager + def process_instruction_id(self, bundle_id: str) -> Iterator[None]: + self._instruction_id = bundle_id + yield + self._instruction_id = None + + def done(self): + pass