diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index 46b6fc103271ee..e749d43ff84ac5 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -32,6 +32,7 @@ xla_cc_binary( testonly = True, srcs = ["ifrt-opt.cc"], deps = [ + "//tensorflow/compiler/mlir:init_mlir", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/pjrt:pjrt_executable", "//xla/python/ifrt", diff --git a/xla/python/ifrt/ir/tests/ifrt-opt.cc b/xla/python/ifrt/ir/tests/ifrt-opt.cc index 593e9737c5812d..d19f6560538c18 100644 --- a/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "third_party/tensorflow/compiler/mlir/init_mlir.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/python/ifrt/dtype.h" @@ -117,6 +118,8 @@ class TestChildExecutableCompiler : public AtomProgramCompiler { } // namespace xla int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + std::shared_ptr compiler = std::make_shared(); auto compile_options = std::make_shared +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/passes.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/tsl/platform/env.h" +#include "tsl/platform/path.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTDUMPATOMPROGRAMSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +absl::Status DumpOperation(mlir::Operation* op, std::string dump_dir, + std::string filename) { + std::string file_path = + tsl::io::JoinPath(dump_dir, absl::StrCat(filename, ".mlir")); + return tsl::WriteStringToFile(tsl::Env::Default(), file_path, + OperationToString(op, mlir::OpPrintingFlags())); +} + +class IfrtDumpAtomProgramsPass + : public impl::IfrtDumpAtomProgramsPassBase { + public: + using impl::IfrtDumpAtomProgramsPassBase< + IfrtDumpAtomProgramsPass>::IfrtDumpAtomProgramsPassBase; + + void runOnOperation() override { + if (dump_dir.empty()) { + return signalPassFailure(); + } + + mlir::SymbolTableCollection symbol_table; + mlir::ModuleOp module_op = getOperation(); + // Keeps track of the atom programs that have already been dumped. + absl::flat_hash_set dumped_atom_program_names; + + auto main_func = GetMainFunction(module_op); + + // Clones the main function to ensure that the attribute aliases are + // preserved while printing. Otherwise, the op would be printed in its + // full form (i.e., every argument with the entire device list expanded) + // and would lead to large ifrt dump files. + auto cloned_main = main_func.clone(); + if (auto status = DumpOperation(cloned_main, dump_dir, "ifrt_main_func"); + !status.ok()) { + cloned_main.erase(); + main_func->emitOpError() + << "failed to dump main func: " << status.ToString(); + signalPassFailure(); + return; + } + cloned_main.erase(); + + mlir::WalkResult result = + main_func.walk([&](CallOp call_op) -> mlir::WalkResult { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + CHECK(callee != nullptr); + auto atom_program_module = + llvm::cast(callee->getParentOp()); + std::string atom_program_name = + atom_program_module.getSymNameAttr().str(); + if (dumped_atom_program_names.insert(atom_program_name).second) { + if (auto status = DumpOperation(atom_program_module, dump_dir, + atom_program_name); + !status.ok()) { + return call_op->emitOpError() + << "failed to dump atom program: " << status.ToString(); + } + } + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateIfrtDumpAtomProgramsPass(IfrtDumpAtomProgramsPassOptions options) { + return std::make_unique(options); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/transforms/passes.h b/xla/python/ifrt/ir/transforms/passes.h index 4b6263f55bb05a..2882c4782fb59d 100644 --- a/xla/python/ifrt/ir/transforms/passes.h +++ b/xla/python/ifrt/ir/transforms/passes.h @@ -129,6 +129,9 @@ std::unique_ptr> CreateIfrtVerifyDeviceTypeConsistencyPass( IfrtVerifyDeviceTypeConsistencyPassOptions options = {}); +std::unique_ptr> +CreateIfrtDumpAtomProgramsPass(IfrtDumpAtomProgramsPassOptions options = {}); + std::unique_ptr> CreateIfrtAtomProgramsToVhloPass( tsl::protobuf::RepeatedPtrField* atom_programs, diff --git a/xla/python/ifrt/ir/transforms/passes.td b/xla/python/ifrt/ir/transforms/passes.td index ab768522c6bbfa..1418469af17290 100644 --- a/xla/python/ifrt/ir/transforms/passes.td +++ b/xla/python/ifrt/ir/transforms/passes.td @@ -474,6 +474,16 @@ This pass fails if ]; } +def IfrtDumpAtomProgramsPass + : Pass<"ifrt-dump-atom-programs", "mlir::ModuleOp"> { + let summary = "Extracts atom programs from module and dumps them to files."; + let constructor = "CreateIfrtDumpAtomProgramsPass()"; + let options = [ + Option<"dump_dir", "dump_dir", "std::string", "", + "The directory to dump the atom programs and the main function to.">, + ]; +} + def IfrtLegalizeToVifrtPass : Pass<"ifrt-legalize-to-vifrt", "mlir::ModuleOp"> { let summary = "Legalize IFRT to VIFRT."; let dependentDialects = ["xla::ifrt::VifrtDialect"];