From c1c63f1163e1ba53a68210b8b25082126ebf3170 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Sat, 11 Jan 2025 20:41:15 -0800 Subject: [PATCH] Add a utility pass for writing atom programs and main IFRT func to files. PiperOrigin-RevId: 714569854 --- xla/python/ifrt/ir/tests/BUILD | 1 + xla/python/ifrt/ir/tests/ifrt-opt.cc | 3 + .../ifrt_compile_and_propagate_shardings.mlir | 2 +- .../ir/tests/ifrt_compile_atom_program.mlir | 2 +- .../ifrt_duplicated_callee_elimination.mlir | 2 +- ...rt_lower_atom_program_metadata_to_xla.mlir | 2 +- .../ifrt_lower_mpmd_reshard_to_call.mlir | 2 +- .../ifrt/ir/tests/ifrt_merge_reshards.mlir | 2 +- .../ifrt_outline_atom_program_to_module.mlir | 2 +- .../ifrt_populate_atom_program_metadata.mlir | 2 +- ...precompile_atom_program_preprocessing.mlir | 2 +- ...ifrt_remove_attrs_from_other_dialects.mlir | 2 +- .../ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir | 2 +- .../ir/tests/ifrt_reshard_to_copy_arrays.mlir | 2 +- .../ifrt_verify_device_type_consistency.mlir | 2 +- .../ifrt/ir/tests/ifrt_verify_donation.mlir | 2 +- .../tests/ifrt_verify_sharding_specified.mlir | 2 +- xla/python/ifrt/ir/tests/spmd_expansion.mlir | 2 +- .../ir/tests/spmd_interface_verification.mlir | 2 +- xla/python/ifrt/ir/tests/verify_array.mlir | 2 +- xla/python/ifrt/ir/tests/verify_assemble.mlir | 2 +- xla/python/ifrt/ir/tests/verify_attrs.mlir | 2 +- xla/python/ifrt/ir/tests/verify_call.mlir | 2 +- .../tests/verify_call_loaded_executable.mlir | 2 +- .../ifrt/ir/tests/verify_copy_arrays.mlir | 2 +- .../ifrt/ir/tests/verify_disassemble.mlir | 2 +- .../ir/tests/verify_loaded_executable.mlir | 2 +- .../ifrt/ir/tests/verify_remap_arrays.mlir | 2 +- xla/python/ifrt/ir/tests/verify_reshard.mlir | 2 +- .../vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir | 2 +- .../tests/vifrt/ifrt_legalize_to_vifrt.mlir | 2 +- .../vifrt/ifrt_legalize_to_vifrt_invalid.mlir | 2 +- xla/python/ifrt/ir/transforms/BUILD | 3 + .../ifrt_dump_atom_programs_pass.cc | 117 ++++++++++++++++++ xla/python/ifrt/ir/transforms/passes.h | 3 + xla/python/ifrt/ir/transforms/passes.td | 10 ++ 36 files changed, 167 insertions(+), 30 deletions(-) create mode 100644 xla/python/ifrt/ir/transforms/ifrt_dump_atom_programs_pass.cc diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index 46b6fc103271ee..60b86dcf3a3d86 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -51,6 +51,7 @@ xla_cc_binary( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", + "@tsl//tsl/platform:platform_port", ], ) diff --git a/xla/python/ifrt/ir/tests/ifrt-opt.cc b/xla/python/ifrt/ir/tests/ifrt-opt.cc index 593e9737c5812d..c8038af92ebe8f 100644 --- a/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/python/ifrt/mock.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/support/module_parsing.h" +#include "tsl/platform/init_main.h" namespace xla { namespace ifrt { @@ -117,6 +118,8 @@ class TestChildExecutableCompiler : public AtomProgramCompiler { } // namespace xla int main(int argc, char** argv) { + tsl::port::InitMain(argv[0], &argc, &argv); + std::shared_ptr compiler = std::make_shared(); auto compile_options = std::make_shared, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir b/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir index 22257730e01d5e..fa08961d03ff57 100644 --- a/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-compile-atom-program -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-compile-atom-program -split-input-file | FileCheck %s // CHECK-LABEL: @call_hlo !array = !ifrt.array, diff --git a/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir b/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir index 5cf62e23e0e59e..94291fcc9e5d4a 100644 --- a/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_duplicated_callee_elimination.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-duplicated-callee-elimination | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-duplicated-callee-elimination | FileCheck %s // CHECK-LABEL: @main func.func @main(%arg0: !ifrt.array, diff --git a/xla/python/ifrt/ir/tests/ifrt_lower_atom_program_metadata_to_xla.mlir b/xla/python/ifrt/ir/tests/ifrt_lower_atom_program_metadata_to_xla.mlir index 52b1a2ea38aa73..d76ce277bca94c 100644 --- a/xla/python/ifrt/ir/tests/ifrt_lower_atom_program_metadata_to_xla.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_lower_atom_program_metadata_to_xla.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-lower-atom-program-metadata-to-xla -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-lower-atom-program-metadata-to-xla -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: @arg_metadata module @arg_metadata attributes {ifrt.num_devices = 2} { diff --git a/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir b/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir index fe3331a9fa306e..64059a07a5c5f3 100644 --- a/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-lower-mpmd-reshard-to-call -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-lower-mpmd-reshard-to-call -split-input-file -verify-diagnostics | FileCheck %s !array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir b/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir index 4f8f0e20bc60cc..45239d594260c1 100644 --- a/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_merge_reshards.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-merge-reshards | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-merge-reshards | FileCheck %s #sharding = #ifrt.sharding_param<2 to [0] on 2> !array0 = !ifrt.array, #sharding, [0,1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir b/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir index c963b4ccb7a604..8712225499b638 100644 --- a/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-outline-atom-program-to-module -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-outline-atom-program-to-module -split-input-file -verify-diagnostics | FileCheck %s !array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir b/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir index 790091a44a4953..6a47f1c0fb89f9 100644 --- a/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-populate-atom-program-metadata -ifrt-duplicated-callee-elimination -symbol-dce -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-populate-atom-program-metadata -ifrt-duplicated-callee-elimination -symbol-dce -split-input-file | FileCheck %s !array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1], diff --git a/xla/python/ifrt/ir/tests/ifrt_precompile_atom_program_preprocessing.mlir b/xla/python/ifrt/ir/tests/ifrt_precompile_atom_program_preprocessing.mlir index 9cbdef58dc2d18..c4b862d9aa95e0 100644 --- a/xla/python/ifrt/ir/tests/ifrt_precompile_atom_program_preprocessing.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_precompile_atom_program_preprocessing.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-precompile-atom-program-preprocessing='platform_names=tpu,tpu' -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-precompile-atom-program-preprocessing='platform_names=tpu,tpu' -split-input-file -verify-diagnostics | FileCheck %s #sharding = #ifrt.sharding_param<2x1 to [0] on 2> !array = !ifrt.array, #sharding, [0, 1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_remove_attrs_from_other_dialects.mlir b/xla/python/ifrt/ir/tests/ifrt_remove_attrs_from_other_dialects.mlir index 25737223246979..415c2170ef28f7 100644 --- a/xla/python/ifrt/ir/tests/ifrt_remove_attrs_from_other_dialects.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_remove_attrs_from_other_dialects.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-remove-attrs-from-other-dialects -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-remove-attrs-from-other-dialects -split-input-file | FileCheck %s !array = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir b/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir index 381a7b1d77ec9c..3d0b6431cc81f5 100644 --- a/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-remove-ifrt-attrs | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-remove-ifrt-attrs | FileCheck %s // CHECK-LABEL: @ifrt_attributes_are_removed // CHECK-NOT: ifrt diff --git a/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir b/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir index cf8231b67050a4..dcaa241de1e3ab 100644 --- a/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-reshard-to-copy-arrays -verify-diagnostics -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-reshard-to-copy-arrays -verify-diagnostics -split-input-file | FileCheck %s !array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_device_type_consistency.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_device_type_consistency.mlir index 426a7b87fa9b5e..4bd9473e0d1c73 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_device_type_consistency.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_device_type_consistency.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-verify-device-type-consistency='platform_names=tpu,tpu,cpu,tpu,cpu,cuda,cuda' -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-verify-device-type-consistency='platform_names=tpu,tpu,cpu,tpu,cpu,cuda,cuda' -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: @good_call_multiple #sharding = #ifrt.sharding_param<2 to [0] on 2> diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 92bed2748c2188..6e21390cb9cae8 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-verify-donation -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-verify-donation -split-input-file -verify-diagnostics | FileCheck %s !array0 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir index f8b37fa87d2429..72d76050fba09c 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: @good_arrays #sharding = #ifrt.sharding_param<2 to [0] on 2> diff --git a/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/xla/python/ifrt/ir/tests/spmd_expansion.mlir index 4fef0876dc8bb8..b24690d4419720 100644 --- a/xla/python/ifrt/ir/tests/spmd_expansion.mlir +++ b/xla/python/ifrt/ir/tests/spmd_expansion.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -spmd-expansion -split-input-file -verify-diagnostics | FileCheck %s +// RUN: ifrt-opt %s -- -spmd-expansion -split-input-file -verify-diagnostics | FileCheck %s #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> diff --git a/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir b/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir index 0bd8ef268e0fa1..096bfa0f9701a7 100644 --- a/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir +++ b/xla/python/ifrt/ir/tests/spmd_interface_verification.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -spmd-expandable-interface-verification='excluded-dialects=arith' -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -spmd-expandable-interface-verification='excluded-dialects=arith' -verify-diagnostics module @good_return_only { func.func @main( diff --git a/xla/python/ifrt/ir/tests/verify_array.mlir b/xla/python/ifrt/ir/tests/verify_array.mlir index edab3b639a4300..22e9acd77b77f4 100644 --- a/xla/python/ifrt/ir/tests/verify_array.mlir +++ b/xla/python/ifrt/ir/tests/verify_array.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_array() { /// Dim 0 of the tensor is sharded into 4 slices. diff --git a/xla/python/ifrt/ir/tests/verify_assemble.mlir b/xla/python/ifrt/ir/tests/verify_assemble.mlir index 9c3416961102d2..7440f3200d7ef7 100644 --- a/xla/python/ifrt/ir/tests/verify_assemble.mlir +++ b/xla/python/ifrt/ir/tests/verify_assemble.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_assemble( %arg0: !ifrt.array, diff --git a/xla/python/ifrt/ir/tests/verify_attrs.mlir b/xla/python/ifrt/ir/tests/verify_attrs.mlir index 45c18e7149367b..6e26f442fc5d98 100644 --- a/xla/python/ifrt/ir/tests/verify_attrs.mlir +++ b/xla/python/ifrt/ir/tests/verify_attrs.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_function_attr() attributes {ifrt.function} { return diff --git a/xla/python/ifrt/ir/tests/verify_call.mlir b/xla/python/ifrt/ir/tests/verify_call.mlir index 202724e44496a0..cc8dc464efcc9c 100644 --- a/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/xla/python/ifrt/ir/tests/verify_call.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_call( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index e41add06877c60..e181b8c8a9d3d3 100644 --- a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir b/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir index 0c3d2a2597dc1e..57784404ef3e31 100644 --- a/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir +++ b/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics !array0 = !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> diff --git a/xla/python/ifrt/ir/tests/verify_disassemble.mlir b/xla/python/ifrt/ir/tests/verify_disassemble.mlir index e36946470c23bc..09bdf055638af4 100644 --- a/xla/python/ifrt/ir/tests/verify_disassemble.mlir +++ b/xla/python/ifrt/ir/tests/verify_disassemble.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_disassemble( %arg0: !ifrt.array, diff --git a/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir b/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir index 23d6d9759ff42a..ab58111e96c3d9 100644 --- a/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir +++ b/xla/python/ifrt/ir/tests/verify_loaded_executable.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics ifrt.LoadedExecutable @good on devices [0,1] : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/xla/python/ifrt/ir/tests/verify_remap_arrays.mlir b/xla/python/ifrt/ir/tests/verify_remap_arrays.mlir index b7c4db63d753a5..8ea3ab9ab3cd06 100644 --- a/xla/python/ifrt/ir/tests/verify_remap_arrays.mlir +++ b/xla/python/ifrt/ir/tests/verify_remap_arrays.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics !array0 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> diff --git a/xla/python/ifrt/ir/tests/verify_reshard.mlir b/xla/python/ifrt/ir/tests/verify_reshard.mlir index 717e5aa747b518..710144c34ac218 100644 --- a/xla/python/ifrt/ir/tests/verify_reshard.mlir +++ b/xla/python/ifrt/ir/tests/verify_reshard.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s -split-input-file -verify-diagnostics +// RUN: ifrt-opt %s -- -split-input-file -verify-diagnostics func.func @good_reshard( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir index 8a0d319746a624..1c2f45343f3df5 100644 --- a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir +++ b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s // RUN: ifrt-translate --serialize --ifrt_version=0.1.0 --atom_program_version=1.8.0 --strip_debuginfo %s | ifrt-translate --deserialize --strip_debuginfo | ifrt-opt > %t.0 // RUN: ifrt-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir index 50f8614469820f..04d0695b3548a8 100644 --- a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir +++ b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s +// RUN: ifrt-opt %s -- --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s // RUN: ifrt-translate --serialize --ifrt_version=current --atom_program_version=current %s | ifrt-translate --deserialize | ifrt-opt > %t.0 // RUN: ifrt-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt_invalid.mlir b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt_invalid.mlir index eb73f12d1b2f42..3e4eacbd11393f 100644 --- a/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt_invalid.mlir +++ b/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt_invalid.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt --ifrt-legalize-to-vifrt --symbol-dce --split-input-file -verify-diagnostics %s +// RUN: ifrt-opt -- --ifrt-legalize-to-vifrt --symbol-dce --split-input-file -verify-diagnostics %s !array_t0 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> diff --git a/xla/python/ifrt/ir/transforms/BUILD b/xla/python/ifrt/ir/transforms/BUILD index c7bf2bd0c2702a..27a293dbda7479 100644 --- a/xla/python/ifrt/ir/transforms/BUILD +++ b/xla/python/ifrt/ir/transforms/BUILD @@ -33,6 +33,7 @@ cc_library( "ifrt_atom_programs_to_vhlo_pass.cc", "ifrt_compile_and_propagate_shardings_pass.cc", "ifrt_compile_atom_program_pass.cc", + "ifrt_dump_atom_programs_pass.cc", "ifrt_duplicated_callee_elimination_pass.cc", "ifrt_legalize_to_vifrt_pass.cc", "ifrt_lower_atom_program_metadata_to_xla_pass.cc", @@ -87,6 +88,7 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", + "//xla/tsl/platform:env", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -114,6 +116,7 @@ cc_library( "@stablehlo//:stablehlo_serialization", "@tsl//tsl/platform:env", "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:path", "@tsl//tsl/platform:protobuf", "@tsl//tsl/platform:statusor", ], diff --git a/xla/python/ifrt/ir/transforms/ifrt_dump_atom_programs_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_dump_atom_programs_pass.cc new file mode 100644 index 00000000000000..b49b3f1ec5304c --- /dev/null +++ b/xla/python/ifrt/ir/transforms/ifrt_dump_atom_programs_pass.cc @@ -0,0 +1,117 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/passes.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/tsl/platform/env.h" +#include "tsl/platform/path.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTDUMPATOMPROGRAMSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +absl::Status DumpOperation(mlir::Operation* op, std::string dump_dir, + std::string filename) { + std::string file_path = + tsl::io::JoinPath(dump_dir, absl::StrCat(filename, ".mlir")); + return tsl::WriteStringToFile(tsl::Env::Default(), file_path, + OperationToString(op, mlir::OpPrintingFlags())); +} + +class IfrtDumpAtomProgramsPass + : public impl::IfrtDumpAtomProgramsPassBase { + public: + using impl::IfrtDumpAtomProgramsPassBase< + IfrtDumpAtomProgramsPass>::IfrtDumpAtomProgramsPassBase; + + void runOnOperation() override { + if (dump_dir.empty()) { + return signalPassFailure(); + } + + mlir::SymbolTableCollection symbol_table; + mlir::ModuleOp module_op = getOperation(); + // Keeps track of the atom programs that have already been dumped. + absl::flat_hash_set dumped_atom_program_names; + + auto main_func = GetMainFunction(module_op); + + // Clones the main function to ensure that the attribute aliases are + // preserved while printing. Otherwise, the op would be printed in its + // full form (i.e., every argument with the entire device list expanded) + // and would lead to large ifrt dump files. + auto cloned_main = main_func.clone(); + if (auto status = DumpOperation(cloned_main, dump_dir, "ifrt_main_func"); + !status.ok()) { + cloned_main.erase(); + main_func->emitOpError() + << "failed to dump main func: " << status.ToString(); + signalPassFailure(); + return; + } + cloned_main.erase(); + + mlir::WalkResult result = + main_func.walk([&](CallOp call_op) -> mlir::WalkResult { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + CHECK(callee != nullptr); + auto atom_program_module = + llvm::cast(callee->getParentOp()); + std::string atom_program_name = + atom_program_module.getSymNameAttr().str(); + if (dumped_atom_program_names.insert(atom_program_name).second) { + if (auto status = DumpOperation(atom_program_module, dump_dir, + atom_program_name); + !status.ok()) { + return call_op->emitOpError() + << "failed to dump atom program: " << status.ToString(); + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateIfrtDumpAtomProgramsPass(IfrtDumpAtomProgramsPassOptions options) { + return std::make_unique(options); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/transforms/passes.h b/xla/python/ifrt/ir/transforms/passes.h index 4b6263f55bb05a..2882c4782fb59d 100644 --- a/xla/python/ifrt/ir/transforms/passes.h +++ b/xla/python/ifrt/ir/transforms/passes.h @@ -129,6 +129,9 @@ std::unique_ptr> CreateIfrtVerifyDeviceTypeConsistencyPass( IfrtVerifyDeviceTypeConsistencyPassOptions options = {}); +std::unique_ptr> +CreateIfrtDumpAtomProgramsPass(IfrtDumpAtomProgramsPassOptions options = {}); + std::unique_ptr> CreateIfrtAtomProgramsToVhloPass( tsl::protobuf::RepeatedPtrField* atom_programs, diff --git a/xla/python/ifrt/ir/transforms/passes.td b/xla/python/ifrt/ir/transforms/passes.td index ab768522c6bbfa..1418469af17290 100644 --- a/xla/python/ifrt/ir/transforms/passes.td +++ b/xla/python/ifrt/ir/transforms/passes.td @@ -474,6 +474,16 @@ This pass fails if ]; } +def IfrtDumpAtomProgramsPass + : Pass<"ifrt-dump-atom-programs", "mlir::ModuleOp"> { + let summary = "Extracts atom programs from module and dumps them to files."; + let constructor = "CreateIfrtDumpAtomProgramsPass()"; + let options = [ + Option<"dump_dir", "dump_dir", "std::string", "", + "The directory to dump the atom programs and the main function to.">, + ]; +} + def IfrtLegalizeToVifrtPass : Pass<"ifrt-legalize-to-vifrt", "mlir::ModuleOp"> { let summary = "Legalize IFRT to VIFRT."; let dependentDialects = ["xla::ifrt::VifrtDialect"];