From 45dd31f7c01db2836e99229f0b83fce01e4cd5af Mon Sep 17 00:00:00 2001 From: Will Froom Date: Thu, 19 Dec 2024 01:34:16 -0800 Subject: [PATCH 1/6] [XLA:CPU] Emit nested computations prior to calling ElementalIrEmitter PiperOrigin-RevId: 707827247 --- xla/service/cpu/ir_emitter.cc | 6 ++ xla/service/cpu/ir_emitter.h | 4 +- xla/service/cpu/ir_emitter2.cc | 101 +++++++++++++++++++-------------- xla/service/cpu/ir_emitter2.h | 4 ++ 4 files changed, 71 insertions(+), 44 deletions(-) diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index a2498bb8b6e63a..bfafea513a3d69 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -224,7 +224,13 @@ absl::StatusOr IrEmitter::EmitComputation( std::string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; + + auto cleanup = absl::MakeCleanup( + [saved_allow_reassociation = allow_reassociation_, this]() { + allow_reassociation_ = saved_allow_reassociation; + }); allow_reassociation_ = allow_reassociation; + num_dynamic_loop_bounds_ = 0; auto backend_config_or = computation->root_instruction()->backend_config(); diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index e56a57ff97789f..926f6b6461ba37 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -637,7 +637,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilderBase* current_builder_; std::stack compute_function_; mlir::MLIRContext* mlir_context_; - bool allow_reassociation_; + // The state of allow_reassociation_ is required so that that it is + // transitive to all nested computations. + bool allow_reassociation_ = false; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index ea63cb5a44a045..621fffbdfa3329 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -99,10 +99,8 @@ KernelApiIrBuilder::Options KernelApiIrBuilderOptionsFromHloModuleConfig( class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { public: ElementalIrEmitter(llvm::Module* module, llvm::IRBuilderBase* b, - const HloModule* hlo_module, IrEmitter* nested_ir_emitter, - bool fast_min_max) + IrEmitter* nested_ir_emitter, bool fast_min_max) : CpuElementalIrEmitter(module, b, true, fast_min_max), - hlo_module_(hlo_module), nested_ir_emitter_(nested_ir_emitter), fast_min_max_(fast_min_max) {} @@ -110,43 +108,8 @@ class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { absl::StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view name, bool is_reducer) override { - // Module must be scheduled to emit thread local computation. - if (!hlo_module_ || !hlo_module_->has_schedule()) { - return absl::InternalError( - "HLO module must be scheduled to emit thread local computation."); - } - - // Create a nested function for thread local computation(s) if it is not - // already created. Nested functions are created with internal linkage. - auto emit_computation = [&](const HloComputation* computation) { - if (!nested_ir_emitter_->is_computation_emitted(*computation, - is_reducer)) { - VLOG(2) << "Emit nested computation: " << computation->name(); - TF_RETURN_IF_ERROR( - nested_ir_emitter_ - ->EmitComputation( - const_cast(computation), name, false, - hlo_module_->schedule() - .sequence(computation) - .instructions(), - /*allow_reassociation=*/is_reducer, - /*function_attributes=*/{llvm::Attribute::AlwaysInline}) - .status()); - } - return absl::OkStatus(); - }; - - // We emit all embedded computations reachable through the `callee` to - // support nested thread local call, i.e., nested map computations. - for (HloComputation* embedded : callee.MakeEmbeddedComputationsList()) { - if (embedded->IsFusionComputation()) continue; - TF_RETURN_IF_ERROR(emit_computation(embedded)); - } - TF_RETURN_IF_ERROR(emit_computation(&callee)); - // Add a thread local call to the nested computation. VLOG(2) << "Emit thread local call to: " << callee.name(); - nested_ir_emitter_->b()->SetInsertPoint(b()->GetInsertPoint()); auto values = nested_ir_emitter_->EmitThreadLocalCall( callee, parameters, name, is_reducer, /*in_compute_function=*/false); @@ -156,7 +119,6 @@ class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { bool fast_min_max() override { return fast_min_max_; } private: - const HloModule* hlo_module_; IrEmitter* nested_ir_emitter_; bool fast_min_max_; }; @@ -195,6 +157,8 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); + IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (int64_t i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); @@ -203,8 +167,16 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( }; } - ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); + if (instr->has_to_apply()) { + HloComputation* nested_computation = instr->to_apply(); + bool is_reducer = instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + TF_RETURN_IF_ERROR(EmitNestedComputation( + *nested_computation, llvm_ir::IrName(instr), is_reducer)); + } + + ElementalIrEmitter elemental_emitter(module_, &b, nested_ir_emitter_, + fast_min_max()); llvm_ir::ElementGenerator element_generator = elemental_emitter.MakeElementGenerator(instr, operand_to_generator); @@ -266,8 +238,14 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); + IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); + + HloComputation* nested_computation = fusion->fused_instructions_computation(); + TF_RETURN_IF_ERROR(EmitNestedComputation(*nested_computation, + llvm_ir::IrName(fusion), false)); + + ElementalIrEmitter elemental_emitter(module_, &b, nested_ir_emitter_, + fast_min_max()); FusedIrEmitter fused_emitter(elemental_emitter); for (int i = 0; i < fusion->operand_count(); i++) { @@ -911,6 +889,43 @@ absl::StatusOr IrEmitter2::EmitElementalLoops( return se::ThreadDim(); } +absl::Status IrEmitter2::EmitNestedComputation(const HloComputation& callee, + absl::string_view name, + bool is_reducer) { + // Module must be scheduled to emit thread local computation. + if (!hlo_module_.has_schedule()) { + return absl::InternalError( + "HLO module must be scheduled to emit thread local computation."); + } + + if (nested_ir_emitter_->is_computation_emitted(callee, is_reducer)) { + return absl::OkStatus(); + } + + for (HloInstruction* instr : callee.instructions()) { + bool nested_is_reducer = instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + for (HloComputation* called_computation : instr->called_computations()) { + // reassociation is transitive so we "or" the caller and the callee. + TF_RETURN_IF_ERROR( + EmitNestedComputation(*called_computation, llvm_ir::IrName(instr), + is_reducer || nested_is_reducer)); + } + } + + if (callee.IsFusionComputation()) { + return absl::OkStatus(); + } + + VLOG(2) << "Emit nested computation: " << callee.name(); + return nested_ir_emitter_ + ->EmitComputation(const_cast(&callee), name, false, + hlo_module_.schedule().sequence(&callee).instructions(), + /*allow_reassociation=*/is_reducer, + /*function_attributes=*/{llvm::Attribute::AlwaysInline}) + .status(); +} + // This is a convenience function taken from IrEmitter, it uses module_ class // field. If there will be more functions that use module_, we should consider // refactoring (like we did for compute_function_ and builder_). diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index eafaa99e123006..be7048414de2b0 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -228,6 +229,9 @@ class IrEmitter2 { const KernelPrototype& kernel_prototype, const llvm_ir::ElementGenerator& element_generator); + absl::Status EmitNestedComputation(const HloComputation& callee, + absl::string_view name, bool is_reducer); + bool fast_min_max() const; // Returns the number of bytes within the shape. From 5ed9f2fe1fe421f53e747449d956c41bfeb60298 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 19 Dec 2024 02:16:28 -0800 Subject: [PATCH 2/6] PR #20635: Remove workspace size for SDPA FP8 custom-call tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/20635 Related to #20564 where only one of the commits is merged. @hawkinsp The 2 tests affected by the workspace size on blackwell are FlashAttentionBMMScaleSoftmaxBMMF8.Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BNTH_F8 and FlashAttentionBMMScaleSoftmaxBMMF8.Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BNTH_F8. On Blackwell, the required workspace size is 0 as oppose to 16 on Hopper. Removing hardcoded workspace size to have cuDNN compiler handle it automatically. Copybara import of the project: -- 83153e7cd138f9ef3619ba38b5760e644c62037b by “wenscarl” : Remove hard-coded workspace size for FP8 SPDA tests Merging this change closes #20635 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20635 from wenscarl:spda_fp8_custom_call_workspace 83153e7cd138f9ef3619ba38b5760e644c62037b PiperOrigin-RevId: 707837211 --- xla/service/gpu/tests/gpu_fused_mha_test.cc | 12 ++++-------- .../gpu/transforms/cudnn_custom_call_compiler.cc | 3 +++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index e8d8a04f1a93ec..33214758e230fd 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1471,8 +1471,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, XlaBuilder builder(TestName()); std::string ref_bnth = R"( custom-call.4.0 = ( - bf16[4,4,16,16]{3,1,2,0}, - u8[0]{0} + bf16[4,4,16,16]{3,1,2,0} ) custom-call( convert.19, convert.31, @@ -1546,8 +1545,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, @@ -1652,8 +1650,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, std::string ref_btnh = R"( custom-call.4.0 = ( - bf16[4,16,4,16]{3,2,1,0}, - u8[0]{0} + bf16[4,16,4,16]{3,2,1,0} ) custom-call( convert.19, convert.31, @@ -1726,8 +1723,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index b711f3142f3328..0dc92c47d2cb55 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -393,6 +393,9 @@ class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { : dnn_support_(dnn_support), compilation_results_(compilation_results) {} void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } VLOG(4) << "Applying workspace size " << workspace_size << " to " << hlo.ToString(); Shape *shape = hlo.mutable_shape(); From 45597a0948298cf34ae9095ffd71c1ddf37b3a95 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Thu, 19 Dec 2024 04:28:59 -0800 Subject: [PATCH 3/6] [XLA:GPU] Implement i4 support as a Triton IR rewrite. Right now the triton emitter for multiply emits the code that operates with the i4 tensors packed into i8 with one 2x smaller dimension together with the unpacking steps. It makes sense to rework this taking into the account the fact that we also want to replace these emitters with the very complicated tailing logic with the new triton emitters. The emitter could generate the code that operates with i4 tensors as is. I.e. emit the ops with AxBxi4 tensors and use ExtSI when we need to get i8. That would make the emitter simpler. After that we could do a Triton IR rewrite pass that would convert these i4 ops to i4 packed into i8 ops, and replace ExtSI to the unpacking sequence. The cl is the example of such rewriter that covers the case with i4 tiles packed along the major dim. PiperOrigin-RevId: 707867342 --- xla/service/gpu/fusions/triton/BUILD | 5 + .../triton/compilation_pipeline_cuda.cc | 2 + .../fusions/triton/xla_triton_int4_passes.cc | 324 ++++++++++++++++++ .../gpu/fusions/triton/xla_triton_passes.h | 1 + .../gpu/fusions/triton/xla_triton_passes.td | 11 + .../gpu/tests/int4_to_packed_int4.mlir | 110 ++++++ .../gpu/tests/int4_to_packed_int4_small.mlir | 12 + 7 files changed, 465 insertions(+) create mode 100644 xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc create mode 100644 xla/service/gpu/tests/int4_to_packed_int4.mlir create mode 100644 xla/service/gpu/tests/int4_to_packed_int4_small.mlir diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 8ed220d5d3e192..a0ae574269ca37 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -374,6 +374,7 @@ gentbl_cc_library( cc_library( name = "xla_triton_passes", srcs = [ + "xla_triton_int4_passes.cc", "xla_triton_prevent_mmav3_loop_unrolling_pass.cc", "xla_triton_sparse_passes.cc", ], @@ -383,9 +384,12 @@ cc_library( deps = [ ":xla_triton", ":xla_triton_passes_inc_gen", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", @@ -393,6 +397,7 @@ cc_library( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@triton//:TritonAnalysis", diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 6bd49df697a7d9..2ce0a8039309b4 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -48,6 +48,8 @@ absl::Status CreateTritonPipeline( const int ccAsInt = cc.major * 10 + cc.minor; const int threadsPerWarp = 32; + pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass()); + // Based on make_ttir() in // @triton//:third_party/nvidia/backend/compiler.py pm->addPass(mlir::createInlinerPass()); diff --git a/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc b/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc new file mode 100644 index 00000000000000..091970f645ee5d --- /dev/null +++ b/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc @@ -0,0 +1,324 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir::triton::xla { + +using ::xla::llvm_ir::DumpToString; + +namespace mt = ::mlir::triton; +namespace ma = ::mlir::arith; + +#define GEN_PASS_DEF_LOADINT4REWRITEPASS +#include "xla/service/gpu/fusions/triton/xla_triton_passes.h.inc" + +class I4ToI8Converter : public TypeConverter { + public: + static Type convertIntegerType(IntegerType type) { + VLOG(10) << "I4ToI8Converter: converting IntegerType for " + << DumpToString(type); + if (type.getWidth() == 4) { + auto new_type = IntegerType::get(type.getContext(), 8); + VLOG(10) << " -> I4ToI8Converter: IntegerType converted to " + << DumpToString(new_type); + return new_type; + } + return type; + } + static Type convertRankedTensorType(RankedTensorType type) { + VLOG(10) << "I4ToI8Converter: RankedTensorType for " << DumpToString(type); + if (!type.getElementType().isInteger(4)) return type; + + auto shape = type.getShape(); + if (shape[0] == ShapedType::kDynamic) + return type; // Only handle static shapes for simplicity + + std::vector newShape(shape.begin(), shape.end()); + newShape[0] /= 2; + auto new_type = + RankedTensorType::get(newShape, IntegerType::get(type.getContext(), 8)); + VLOG(10) << " -> I4ToI8Converter: RankedTensorType converted to " + << DumpToString(new_type); + return new_type; + } + + PointerType convertPointerType(PointerType ptr_type) { + VLOG(10) << "I4ToI8Converter: converting PointerType for " + << DumpToString(ptr_type); + auto pointee_type = ptr_type.getPointeeType(); + auto new_pointee_type = convertType(pointee_type); + auto new_ptr_type = + PointerType::get(new_pointee_type, ptr_type.getAddressSpace()); + VLOG(10) << " -> I4ToI8Converter: converted PointerType to " + << DumpToString(new_ptr_type); + return new_ptr_type; + } + Type convertFunctionType(FunctionType func_type) { + VLOG(10) << "I4ToI8Converter: converting FunctionType " + << DumpToString(func_type); + + SmallVector inputs; + if (failed(convertTypes(func_type.getInputs(), inputs))) return func_type; + + SmallVector results; + if (failed(convertTypes(func_type.getResults(), results))) return func_type; + + auto new_func_type = + FunctionType::get(func_type.getContext(), inputs, results); + VLOG(10) << " -> I4ToI8Converter: converted FunctionType to " + << DumpToString(new_func_type); + return new_func_type; + } + + I4ToI8Converter() { + // Passthrough for other types. + addConversion([](Type type) { + VLOG(10) << "I4ToI8Converter: passthrough for " << DumpToString(type); + return type; + }); + + // Convert i4 to i8 + addConversion( + [this](IntegerType type) { return this->convertIntegerType(type); }); + + // Convert tensor to tensor + addConversion([this](RankedTensorType type) { + return this->convertRankedTensorType(type); + }); + + // Convert !tt.ptr> to !tt.ptr> + addConversion( + [this](PointerType type) { return this->convertPointerType(type); }); + + // Convert function type to function type + addConversion( + [this](FunctionType type) { return this->convertFunctionType(type); }); + } +}; + +class MakeTensorPtrOpConversionPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MakeTensorPtrOp op, + OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + auto loc = op.getLoc(); + Value c2 = + r.create(loc, r.getIntegerAttr(r.getI64Type(), 2)); + SmallVector shape{adaptor.getShape().begin(), + adaptor.getShape().end()}; + // The packing dim is major and it should twice smaller. + shape[0] = r.create(loc, shape[0], c2); + + // The packing dim is major and the other stride should be half of the + // original one. + SmallVector new_strides = adaptor.getStrides(); + new_strides[1] = r.create(loc, new_strides[1], c2); + + r.replaceOpWithNewOp( + op, new_type, adaptor.getBase(), shape, new_strides, + adaptor.getOffsets(), adaptor.getOrderAttr()); + + return success(); + } +}; + +class AddPtrOpConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddPtrOp op, OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + // The increment for the next stripe of tiles along K dimension should be + // twice smaller. + auto ptr = adaptor.getOperands()[0]; + auto offset = adaptor.getOperands()[1]; + auto offset_type = offset.getType(); + Value c2 = + r.create(op.getLoc(), r.getIntegerAttr(offset_type, 2)); + auto new_offset = + r.create(op.getLoc(), offset_type, offset, c2); + + r.replaceOpWithNewOp(op, new_type, ptr, new_offset); + + return success(); + } +}; + +template +class OpTypeConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpType op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + VLOG(10) << "OpTypeConversionPattern: matching\n" + << DumpToString(static_cast(op.getOperation())); + // Convert the tensor type using the TypeConverter + auto new_type = + OpConversionPattern::getTypeConverter()->convertType( + op.getType()); + if (op.getType() == new_type) { + VLOG(10) << "OpTypeConversionPattern: no conversion needed for " + << DumpToString(op.getType()); + return r.notifyMatchFailure(op, "no conversion needed"); + } + + r.replaceOpWithNewOp(op, new_type, adaptor.getOperands(), + op->getAttrs()); + return success(); + } +}; + +// The pattern converts the ExtSIOp that converts i4 tensor to i8 tensor to the +// unpack sequence with ShLIOp, ShRSIOp, JoinOp, TransOp and ReshapeOp that does +// the same thing. +class ExtSIInt4ToInt8Pattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ma::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + auto i4_tensor = cast(op.getType()); + const auto &operand_type = cast(op.getIn().getType()); + + auto i4_type = r.getI4Type(); + auto i8_type = r.getI8Type(); + + if (operand_type.getElementType() != i4_type) { + return r.notifyMatchFailure(op, "not i4 operand"); + } + + // Make a new i8 tensor with the shape that is half of the int4 tensor. + SmallVector result_shape(i4_tensor.getShape()); + result_shape[0] /= 2; + auto i8_tensor = RankedTensorType::get(result_shape, i8_type); + + auto loc = op.getLoc(); + + Value shift4_const = + r.create(loc, r.getIntegerAttr(i8_type, 4)); + Value shift4 = r.create(loc, i8_tensor, shift4_const); + Value shifted_lo = + r.create(loc, i8_tensor, adaptor.getIn(), shift4); + Value lo = r.create(loc, i8_tensor, shifted_lo, shift4); + Value hi = r.create(loc, i8_tensor, adaptor.getIn(), shift4); + Value hi_lo = r.create(loc, hi, lo); + auto trans_attr = r.getDenseI32ArrayAttr({0, 2, 1}); + + Value trans_hi_lo = r.create(loc, hi_lo, trans_attr); + + r.replaceOpWithNewOp(op, i4_tensor, trans_hi_lo, + /*allow_reorder=*/false); + return success(); + } +}; + +struct PlainInt4ToPackedInt4RewritePass + : public impl::LoadInt4RewritePassBase { + void runOnOperation() override { + auto *ctx = &getContext(); + auto module = getOperation(); + + ConversionTarget target(*ctx); + + VLOG(10) << "before TypeRewrite rewrite"; + { + I4ToI8Converter converter; + ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + if (auto func_op = dyn_cast(op)) { + VLOG(10) << "check funcOp: " << DumpToString(func_op); + if (func_op.getFunctionType() != + converter.convertType(func_op.getFunctionType())) { + VLOG(10) << "funcOp not legal: " << DumpToString(func_op); + return false; + } + } + bool is_legal = converter.isLegal(op); + VLOG(10) << "is_legal: " << is_legal << " for " << DumpToString(op); + return is_legal; + }); + RewritePatternSet patterns(ctx); + scf::populateSCFStructuralTypeConversions(converter, patterns); + patterns.add(ctx); + patterns.add>(converter, ctx); + patterns.add>(converter, ctx); + patterns.add(converter, ctx); + patterns.add(converter, ctx); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + VLOG(10) << "failed to apply partial conversion"; + signalPassFailure(); + } + } + VLOG(10) << "after TypeRewrite Module: " << DumpToString(module); + } +}; + +// The pass converts the types like tensor to tensor in the +// Triton dialect and replaces the ExtSIOp with the unpack sequence that accepts +// twice smaller i8 tensor and convert it to the twice bigger i8 tensor where +// every i4 element uses i8 space. At the end the module accepts the tt.ptr +// to the packed i4 tensor, and unpacks it to the i8 tensor for the further +// processing. It expects that the i4 tensor is packed along the major +// dimension. +std::unique_ptr CreateInt4ToPackedInt4RewritePass() { + return std::make_unique(); +} + +} // namespace mlir::triton::xla diff --git a/xla/service/gpu/fusions/triton/xla_triton_passes.h b/xla/service/gpu/fusions/triton/xla_triton_passes.h index 10f5e684cb5516..67034fe1df1897 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_passes.h +++ b/xla/service/gpu/fusions/triton/xla_triton_passes.h @@ -36,6 +36,7 @@ std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); +std::unique_ptr CreateInt4ToPackedInt4RewritePass(); // Returns true if the `op` contains an operation in it's regions that satisfies // the `fn`. diff --git a/xla/service/gpu/fusions/triton/xla_triton_passes.td b/xla/service/gpu/fusions/triton/xla_triton_passes.td index 49e003e392ed15..21db540475b390 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_passes.td +++ b/xla/service/gpu/fusions/triton/xla_triton_passes.td @@ -95,4 +95,15 @@ def PreventMmaV3LoopUnrollingPass let constructor = "CreatePreventMmaV3LoopUnrollingPass()"; } +def LoadInt4RewritePass + : Pass<"int4-to-packed-int4-rewrite", "mlir::ModuleOp"> { + let summary = "Converts ops with int4 tensors to the ops with int4 packed to int8 tensors."; + let description = [{ + This pass replaces the int4 tensors with the int4 packed to int8 tensor of + the twice smaller size. It also replaces the plain ExtSIOp upcast to the + int8 tensor with the unpack sequence. + }]; + let constructor = "CreateInt4ToPackedInt4RewritePass()"; +} + #endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ diff --git a/xla/service/gpu/tests/int4_to_packed_int4.mlir b/xla/service/gpu/tests/int4_to_packed_int4.mlir new file mode 100644 index 00000000000000..29cdd45524d57c --- /dev/null +++ b/xla/service/gpu/tests/int4_to_packed_int4.mlir @@ -0,0 +1,110 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s --mlir-print-ir-after-all + +module { + tt.func @gemm_fusion_dot_2_impl(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %0 = tt.get_program_id x : i32 + %c16_i32 = arith.constant 16 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = arith.muli %1, %c8_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %3 = arith.subi %c1_i32, %2 : i32 + %4 = arith.cmpi slt, %3, %c8_i32 : i32 + %5 = arith.select %4, %3, %c8_i32 : i32 + %6 = arith.remsi %0, %5 : i32 + %7 = arith.addi %2, %6 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %8 = arith.remsi %0, %c16_i32_0 : i32 + %9 = arith.divsi %8, %5 : i32 + %c128_i32 = arith.constant 128 : i32 + %10 = arith.muli %7, %c128_i32 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %11 = arith.addi %10, %c0_i32 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i32_1 = arith.constant 0 : i32 + %c128_i64_2 = arith.constant 128 : i64 + %c0_i32_3 = arith.constant 0 : i32 + %c128_i64_4 = arith.constant 128 : i64 + %c0_i32_5 = arith.constant 0 : i32 + %12 = arith.addi %c0_i32_3, %c0_i32_5 : i32 + %c64_i64 = arith.constant 64 : i64 + %c0_i32_6 = arith.constant 0 : i32 + %c64_i64_7 = arith.constant 64 : i64 + %c8192_i32 = arith.constant 8192 : i32 + %13 = tt.get_program_id y : i32 + %c0_i32_8 = arith.constant 0 : i32 + %14 = arith.addi %c0_i32_8, %13 : i32 + %15 = arith.muli %14, %c8192_i32 : i32 + %16 = tt.addptr %arg0, %15 : !tt.ptr, i32 + %17 = tt.make_tensor_ptr %16, [%c128_i64_2, %c64_i64_7], [%c1_i64, %c128_i64_4], [%c0_i32_1, %c0_i32_6] {order = array} : > + %18 = tt.advance %17, [%10, %c0_i32_3] : > + %c0_i32_9 = arith.constant 0 : i32 + %c256_i64 = arith.constant 256 : i64 + %c0_i32_10 = arith.constant 0 : i32 + %19 = arith.addi %c0_i32_9, %c0_i32_10 : i32 + %c64_i64_11 = arith.constant 64 : i64 + %c0_i32_12 = arith.constant 0 : i32 + %c64_i64_13 = arith.constant 64 : i64 + %c128_i32_14 = arith.constant 128 : i32 + %20 = arith.muli %9, %c128_i32_14 : i32 + %c1_i64_15 = arith.constant 1 : i64 + %c0_i32_16 = arith.constant 0 : i32 + %21 = arith.addi %20, %c0_i32_16 : i32 + %c256_i64_17 = arith.constant 256 : i64 + %c0_i32_18 = arith.constant 0 : i32 + %c256_i64_19 = arith.constant 256 : i64 + %c16384_i32 = arith.constant 16384 : i32 + %22 = tt.get_program_id y : i32 + %c0_i32_20 = arith.constant 0 : i32 + %23 = arith.addi %c0_i32_20, %22 : i32 + %24 = arith.muli %23, %c16384_i32 : i32 + %25 = tt.addptr %arg1, %24 : !tt.ptr, i32 + %26 = tt.make_tensor_ptr %25, [%c64_i64_13, %c256_i64_19], [%c256_i64, %c1_i64_15], [%c0_i32_12, %c0_i32_18] {order = array} : > + %27 = tt.advance %26, [%c0_i32_9, %20] : > + %c0_i32_21 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c32_i32 = arith.constant 32 : i32 + %28:3 = scf.for %arg3 = %c0_i32_21 to %c64_i32 step %c32_i32 iter_args(%arg4 = %18, %arg5 = %27, %arg6 = %cst) -> (!tt.ptr>, !tt.ptr>, tensor<128x128xf32>) : i32 { + %39 = tt.load %arg4 : !tt.ptr> + %c0_i32_35 = arith.constant 0 : i32 + %c32_i32_36 = arith.constant 32 : i32 + %40 = tt.advance %arg4, [%c0_i32_35, %c32_i32_36] : > + %41 = tt.load %arg5 : !tt.ptr> + %c32_i32_37 = arith.constant 32 : i32 + %c0_i32_38 = arith.constant 0 : i32 + %42 = tt.advance %arg5, [%c32_i32_37, %c0_i32_38] : > + %43 = arith.extsi %39 : tensor<128x32xi4> to tensor<128x32xi8> + %44 = arith.sitofp %43 : tensor<128x32xi8> to tensor<128x32xf32> + %45 = tt.dot %44, %41, %arg6 : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + scf.yield %40, %42, %45 : !tt.ptr>, !tt.ptr>, tensor<128x128xf32> + } + %c128_i32_22 = arith.constant 128 : i32 + %29 = arith.muli %7, %c128_i32_22 : i32 + %c256_i64_23 = arith.constant 256 : i64 + %c0_i32_24 = arith.constant 0 : i32 + %30 = arith.addi %29, %c0_i32_24 : i32 + %c128_i64_25 = arith.constant 128 : i64 + %c0_i32_26 = arith.constant 0 : i32 + %c128_i64_27 = arith.constant 128 : i64 + %c128_i32_28 = arith.constant 128 : i32 + %31 = arith.muli %9, %c128_i32_28 : i32 + %c1_i64_29 = arith.constant 1 : i64 + %c0_i32_30 = arith.constant 0 : i32 + %32 = arith.addi %31, %c0_i32_30 : i32 + %c256_i64_31 = arith.constant 256 : i64 + %c0_i32_32 = arith.constant 0 : i32 + %c256_i64_33 = arith.constant 256 : i64 + %c32768_i32 = arith.constant 32768 : i32 + %33 = tt.get_program_id y : i32 + %c0_i32_34 = arith.constant 0 : i32 + %34 = arith.addi %c0_i32_34, %33 : i32 + %35 = arith.muli %34, %c32768_i32 : i32 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i32 + %37 = tt.make_tensor_ptr %36, [%c128_i64_27, %c256_i64_33], [%c256_i64_23, %c1_i64_29], [%c0_i32_26, %c0_i32_32] {order = array} : > + %38 = tt.advance %37, [%29, %31] : > + tt.store %38, %28#2 : !tt.ptr> + tt.return + } +} diff --git a/xla/service/gpu/tests/int4_to_packed_int4_small.mlir b/xla/service/gpu/tests/int4_to_packed_int4_small.mlir new file mode 100644 index 00000000000000..a7323a4afaed8b --- /dev/null +++ b/xla/service/gpu/tests/int4_to_packed_int4_small.mlir @@ -0,0 +1,12 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s + +module { + tt.func @dot_test(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<16x16xi8> { + %c0 = arith.constant 0 : i32 + %c16 = arith.constant 16: i64 + %0 = tt.make_tensor_ptr %arg0, [%c16, %c16], [%c16, %c16], [%c0, %c0] {order = array} : > + %1 = tt.load %0 : !tt.ptr> + %2 = arith.extsi %1 : tensor<16x16xi4> to tensor<16x16xi8> + tt.return %2 : tensor<16x16xi8> + } +} From bb49f2b14505c7541bb3c82159c40f4ebe27ccf0 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 19 Dec 2024 05:12:45 -0800 Subject: [PATCH 4/6] [XLA:GPU] Remove the call to `FloatNormalization` preceding the normalization rewriter. It should no longer be necessary. This is part of a series of changes aimed at decreasing the slight added complexity introduced by the normalization rewriter. PiperOrigin-RevId: 707877388 --- xla/service/gpu/gpu_compiler.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index f4e87d417eec5a..faeaa7a6c46679 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1584,9 +1584,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( if ((cuda_cc != nullptr && cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || rocm_cc != nullptr) { - // Triton compilation needs normalized operations on bf16 (i.e. converted - // to f32). - add_float_normalization(pipeline); pipeline.AddPass>(simplifier_options, gpu_version); pipeline.AddPass(/*is_layout_sensitive=*/true); From 03d02c57ce0d8de4e43bf5cd9506748cfb753d6f Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 19 Dec 2024 05:20:20 -0800 Subject: [PATCH 5/6] Automated Code Change PiperOrigin-RevId: 707878810 --- xla/hlo/transforms/BUILD | 11 +++++++++++ xla/hlo/transforms/host_offload_legalize.cc | 1 - xla/hlo/transforms/host_offload_legalize.h | 2 ++ xla/hlo/transforms/host_offload_legalize_test.cc | 3 --- xla/hlo/transforms/host_offloader.cc | 8 ++------ xla/hlo/transforms/host_offloader.h | 3 +++ xla/hlo/transforms/host_offloader_test.cc | 1 - xla/hlo/transforms/memory_space_propagation.cc | 4 ++++ xla/hlo/transforms/memory_space_propagation.h | 6 ++++++ xla/hlo/transforms/memory_space_propagation_test.cc | 4 ++++ xla/hlo/transforms/operand_upcaster_test.cc | 3 +++ 11 files changed, 35 insertions(+), 11 deletions(-) diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index d895b3cc917306..84cf00b65702b1 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -1335,6 +1335,9 @@ cc_library( "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -1346,6 +1349,10 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:test_main", ], ) @@ -1838,6 +1845,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:errors", @@ -1879,6 +1887,7 @@ cc_library( "//xla:side_effect_util", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -2274,10 +2283,12 @@ xla_cc_test( deps = [ ":operand_upcaster", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], diff --git a/xla/hlo/transforms/host_offload_legalize.cc b/xla/hlo/transforms/host_offload_legalize.cc index 639e37874ceb4b..5e70dbb26c7d21 100644 --- a/xla/hlo/transforms/host_offload_legalize.cc +++ b/xla/hlo/transforms/host_offload_legalize.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" diff --git a/xla/hlo/transforms/host_offload_legalize.h b/xla/hlo/transforms/host_offload_legalize.h index a5d85fa40a8a5c..e08c842ee0bc68 100644 --- a/xla/hlo/transforms/host_offload_legalize.h +++ b/xla/hlo/transforms/host_offload_legalize.h @@ -17,8 +17,10 @@ #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/xla/hlo/transforms/host_offload_legalize_test.cc b/xla/hlo/transforms/host_offload_legalize_test.cc index 4aedc40b8ca2be..a37a73fc149f9f 100644 --- a/xla/hlo/transforms/host_offload_legalize_test.cc +++ b/xla/hlo/transforms/host_offload_legalize_test.cc @@ -16,12 +16,9 @@ limitations under the License. #include "xla/hlo/transforms/host_offload_legalize.h" #include -#include #include -#include #include -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/xla/hlo/transforms/host_offloader.cc b/xla/hlo/transforms/host_offloader.cc index 7b798fe38eef7b..833fa176b78b00 100644 --- a/xla/hlo/transforms/host_offloader.cc +++ b/xla/hlo/transforms/host_offloader.cc @@ -15,15 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/host_offloader.h" -#include -#include #include #include #include -#include #include -#include -#include #include #include "absl/algorithm/container.h" @@ -35,7 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -56,6 +51,7 @@ limitations under the License. #include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/hlo/transforms/host_offloader.h b/xla/hlo/transforms/host_offloader.h index 765b3c2709856e..8e79a449261783 100644 --- a/xla/hlo/transforms/host_offloader.h +++ b/xla/hlo/transforms/host_offloader.h @@ -18,8 +18,11 @@ #include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" diff --git a/xla/hlo/transforms/host_offloader_test.cc b/xla/hlo/transforms/host_offloader_test.cc index 1452815127f1a7..d38526e93178af 100644 --- a/xla/hlo/transforms/host_offloader_test.cc +++ b/xla/hlo/transforms/host_offloader_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/xla/hlo/transforms/memory_space_propagation.cc b/xla/hlo/transforms/memory_space_propagation.cc index d0704df0e88af9..3dc14572dc408b 100644 --- a/xla/hlo/transforms/memory_space_propagation.cc +++ b/xla/hlo/transforms/memory_space_propagation.cc @@ -16,7 +16,11 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" #include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/xla/hlo/transforms/memory_space_propagation.h b/xla/hlo/transforms/memory_space_propagation.h index bb0da70bf1a7fc..b3998f542d39f5 100644 --- a/xla/hlo/transforms/memory_space_propagation.h +++ b/xla/hlo/transforms/memory_space_propagation.h @@ -16,6 +16,12 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ #define XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/xla/hlo/transforms/memory_space_propagation_test.cc b/xla/hlo/transforms/memory_space_propagation_test.cc index 15cd6c4cd4cbff..a1252d596ee281 100644 --- a/xla/hlo/transforms/memory_space_propagation_test.cc +++ b/xla/hlo/transforms/memory_space_propagation_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" +#include +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/xla/hlo/transforms/operand_upcaster_test.cc b/xla/hlo/transforms/operand_upcaster_test.cc index 8a143b365af618..ed61bb63d2dad6 100644 --- a/xla/hlo/transforms/operand_upcaster_test.cc +++ b/xla/hlo/transforms/operand_upcaster_test.cc @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { From 313d56fc66638fc32abdba49f2614b54df51f900 Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 19 Dec 2024 06:08:08 -0800 Subject: [PATCH 6/6] Rollback breaking C API changes (TryGetKeyValue()). Reverts 926ef6acf5ef98b363127a115c4614ec817d4804 PiperOrigin-RevId: 707888995 --- xla/pjrt/c/CHANGELOG.md | 6 --- xla/pjrt/c/pjrt_c_api.h | 40 +------------------ xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 6 +-- xla/pjrt/c/pjrt_c_api_helpers.cc | 38 ------------------ xla/pjrt/c/pjrt_c_api_helpers.h | 17 +++----- xla/pjrt/c/pjrt_c_api_helpers_test.cc | 8 ---- xla/pjrt/c/pjrt_c_api_test_base.cc | 4 +- xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 36 ++--------------- xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 1 - xla/pjrt/distributed/client.cc | 12 ------ xla/pjrt/distributed/client.h | 4 -- xla/pjrt/distributed/client_server_test.cc | 14 ------- .../distributed/in_memory_key_value_store.cc | 12 ------ .../distributed/in_memory_key_value_store.h | 4 -- .../distributed/key_value_store_interface.h | 7 ---- xla/pjrt/pjrt_c_api_client.cc | 2 - xla/python/xla.cc | 15 ------- xla/python/xla_extension/__init__.pyi | 2 - 18 files changed, 14 insertions(+), 214 deletions(-) diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index d56741eb3500b0..5852c9a54dcc01 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,10 +1,4 @@ # PJRT C API changelog - -## 0.61 -* Added ``PJRT_KeyValueTryGet`` to the KV store interface, - which is non-blocking and immediately returns an error if the - key is not found. - ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index f2fc3b1c507a3c..36d82b0787ba41 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 61 +#define PJRT_API_MINOR 60 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -351,35 +351,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); -// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is -// not found. -typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); - -struct PJRT_KeyValueTryGetCallback_Args { - size_t struct_size; - PJRT_Extension_Base* extension_start; - const char* key; - size_t key_size; - PJRT_CallbackError* callback_error; - void* user_arg; - char* value; // out - size_t value_size; // out - // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to - // delete the value returned by PJRT_KeyValueTryGetCallback. The - // implementation is responsible for copying `value` and then calling - // value_deleter_callback. - PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out -}; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, - value_deleter_callback); - -// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. -// (2) The caller that provides the two callbacks is responsible for avoiding -// key collisions between different users of key-value store (i.e. between -// different plugins, but not between different nodes in one plugin). -typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( - PJRT_KeyValueTryGetCallback_Args* args); - struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -418,15 +389,8 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out - - // Key-value try-get callback provided by the caller of PJRT_Client_Create. - // Same as key-value get callback, but returns `NotFoundError` immediately if - // the key is not found. - PJRT_KeyValueTryGetCallback kv_try_get_callback; - // Will be passed to `kv_try_get_callback` as `user_arg` argument. - void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); // Creates and initializes a new PJRT_Client and returns in `client`. typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 68d36fdb7f5c86..4f53c640a6a3dc 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = pjrt::ToCppKeyValueStore( - args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, - args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = + pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, + args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index ca094063c412aa..cf92041af497d5 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -795,25 +795,6 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } -static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( - xla::KeyValueStoreInterface* kv_store) { - return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - absl::StatusOr output = - kv_store->TryGet(absl::string_view(args->key, args->key_size)); - if (!output.ok()) { - absl::string_view message = output.status().message(); - return (*args->callback_error)( - StatusCodeToPjrtErrorCode(output.status().code()), message.data(), - message.size()); - } - args->value = new char[output->size()]; - std::copy(output->begin(), output->end(), args->value); - args->value_size = output->size(); - args->value_deleter_callback = &PjRtValueDeleterCallback; - return nullptr; - }; -} - static PJRT_KeyValuePutCFunc ToKVPutCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -845,22 +826,6 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } -static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { - return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = - reinterpret_cast(args->user_arg); - if (kv_try_get_c_func == nullptr) { - absl::Status status = xla::InvalidArgument( - "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); - return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), - status.message().data(), - status.message().size()); - } - return (*kv_try_get_c_func)(args); - }; -} - static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -881,12 +846,9 @@ std::unique_ptr ConvertToCKeyValueCallbacks( std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); - kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); - kv_callback_data->c_kv_try_get = - ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index baae41fbeca28d..f530b82f423573 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -218,9 +218,6 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function; -using PJRT_KeyValueTryGetCFunc = - std::function; - using PJRT_KeyValuePutCFunc = std::function; @@ -231,21 +228,17 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr kv_store; - // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to - // kv_store. + // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, - // kv_try_get_c_func and kv_put_c_func. + // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and + // kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; - pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; - PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be -// set as PJRT_Client_Create_Args.kv_get_user_arg, -// PJRT_Client_Create_Args.kv_try_get_user_arg and +// The returned &kv_get_c_func and &kv_put_c_func must be set as +// PJRT_Client_Create_Args.kv_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 6dfce81a1e4514..4b8a59287589ed 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,22 +108,14 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, - kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); - auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); - EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); - auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); - - auto v_2 = converted_kv_store->TryGet("key"); - TF_EXPECT_OK(v.status()); - EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index f867846ebcbd54..9602813c573c52 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,11 +47,9 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; - create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_try_get_callback = nullptr; - create_args.kv_try_get_user_arg = nullptr; + create_args.kv_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 222d689b3b68e8..ec697b08af7841 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -235,13 +235,9 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, - void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), - c_try_get_callback_(c_try_get_callback), - try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} @@ -268,27 +264,6 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } - absl::StatusOr TryGet(absl::string_view key) override { - PJRT_CallbackError callback_error = [](PJRT_Error_Code code, - const char* message, - size_t message_size) { - return new PJRT_Error{absl::Status(static_cast(code), - std::string(message, message_size))}; - }; - PJRT_KeyValueTryGetCallback_Args args; - args.key = key.data(); - args.key_size = key.size(); - args.callback_error = &callback_error; - args.user_arg = try_get_user_arg_; - std::unique_ptr error(c_try_get_callback_(&args)); - if (error != nullptr) { - return error->status; - } - auto result = std::string(args.value, args.value_size); - args.value_deleter_callback(args.value); - return result; - } - absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -313,23 +288,18 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; - PJRT_KeyValueTryGetCallback c_try_get_callback_; - void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_try_get_callback == nullptr || - c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_put_callback == nullptr) { return nullptr; } - return std::make_shared( - c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared(c_get_callback, get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 873845d3ac815f..0ebecc0c251734 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -464,7 +464,6 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 305afe7ae4c6d4..280c60873e9d07 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -54,7 +53,6 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) override; - absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> KeyValueDirGet(absl::string_view key) override; absl::Status KeyValueSet(absl::string_view key, @@ -146,12 +144,6 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( return coord_agent_->GetKeyValue(key, timeout); } -absl::StatusOr -DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( - absl::string_view key) { - return coord_agent_->TryGetKeyValue(key); -} - absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( absl::string_view key) { @@ -224,10 +216,6 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } - absl::StatusOr TryGet(absl::string_view key) override { - return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); - } - absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index 58f4fe367681d2..e597ff158cc674 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -27,7 +27,6 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -117,9 +116,6 @@ class DistributedRuntimeClient { virtual absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; - // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index baec103eced933..f5b7e656fe69a2 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -1029,20 +1029,6 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) { EXPECT_EQ(result.value(), "overwritten_value"); } -TEST_F(ClientServerTest, KeyValueTryGet) { - StartService(/*num_nodes=*/1); - auto client = GetClient(/*node_id=*/0); - TF_ASSERT_OK(client->Connect()); - - ASSERT_THAT(client->KeyValueTryGet("test_key").status(), - StatusIs(absl::StatusCode::kNotFound)); - - TF_ASSERT_OK(client->KeyValueSet("test_key", "value")); - auto result = client->KeyValueTryGet("test_key"); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(result.value(), "value"); -} - TEST_F(ClientServerTest, KeyValueDelete) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc index 49fc73ec87f163..70cc5360ecf7b3 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.cc +++ b/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -41,17 +40,6 @@ absl::StatusOr InMemoryKeyValueStore::Get(absl::string_view key, return kv_store_.find(key)->second; } -absl::StatusOr InMemoryKeyValueStore::TryGet( - absl::string_view key) { - absl::MutexLock lock(&mu_); - auto it = kv_store_.find(key); - if (it == kv_store_.end()) { - return absl::NotFoundError( - absl::StrCat(key, " is not found in the kv store.")); - } - return it->second; -} - absl::Status InMemoryKeyValueStore::Set(absl::string_view key, absl::string_view value) { absl::MutexLock lock(&mu_); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h index 13f50c722bd125..1530633a98b754 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.h +++ b/xla/pjrt/distributed/in_memory_key_value_store.h @@ -21,9 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "xla/pjrt/distributed/key_value_store_interface.h" namespace xla { @@ -33,8 +31,6 @@ class InMemoryKeyValueStore : public KeyValueStoreInterface { absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override; - absl::StatusOr TryGet(absl::string_view key) override; - absl::Status Set(absl::string_view key, absl::string_view value) override; private: diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h index 312ebb8abb6463..29580fb86847b1 100644 --- a/xla/pjrt/distributed/key_value_store_interface.h +++ b/xla/pjrt/distributed/key_value_store_interface.h @@ -38,18 +38,11 @@ class KeyValueStoreInterface { virtual ~KeyValueStoreInterface() = default; // Blocking Get(). - // Useful for listening for a key-value pair that may be set later on. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual absl::StatusOr Get(absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - // Useful for checking key existence. - // There are no concurrency guarantees. To avoid a race / impose an ordering - // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). - virtual absl::StatusOr TryGet(absl::string_view key) = 0; - virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0; }; diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 1f65b13109afc6..8855ef33620e5f 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2578,8 +2578,6 @@ absl::StatusOr> WrapClientAroundCApi( kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; - init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get; - init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; } diff --git a/xla/python/xla.cc b/xla/python/xla.cc index e30af5d4e5e43d..51c96229493e4c 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -672,21 +672,6 @@ NB_MODULE(xla_extension, m) { return nb::bytes(result.data(), result.size()); }, nb::arg("key"), nb::arg("timeout_in_ms")) - .def( - "key_value_try_get", - [](DistributedRuntimeClient& client, std::string key) { - nb::gil_scoped_release gil_release; - return xla::ValueOrThrow(client.KeyValueTryGet(key)); - }, - nb::arg("key")) - .def( - "key_value_try_get_bytes", - [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { - nb::gil_scoped_release gil_release; - std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key)); - return nb::bytes(result.data(), result.size()); - }, - nb::arg("key")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 5fa885f9f92255..2e3862285898f2 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -830,8 +830,6 @@ class DistributedRuntimeClient: def blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int ) -> _Status: ... - def key_value_try_get(self, key: str) -> _Status: ... - def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str,