diff --git a/.gitignore b/.gitignore index 0180cd9112..d533f6f099 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so +python/triton/backends/cuda +python/triton/backends/xpu # Python caches __pycache__/ diff --git a/.gitmodules b/.gitmodules index 3a989c6cc9..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,10 +0,0 @@ -[submodule "third_party/intel_xpu_backend"] - path = third_party/intel_xpu_backend - url = http://github.com/intel/intel-xpu-backend-for-triton -[submodule "third_party/amd_hip_backend"] - path = third_party/amd_hip_backend - url = https://github.com/ROCmSoftwarePlatform/triton - branch = third_party_backend_2 -[submodule "third_party/triton_shared"] - path = third_party/triton_shared - url = https://github.com/microsoft/triton-shared diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1daba92fc8..09e48877c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,11 +46,11 @@ repos: exclude: | (?x)( ^include/triton/external/| - ^python/triton/third_party/ + ^third_party/ ) exclude: | (?x)( ^include/triton/external/| - ^python/triton/third_party/ + ^third_party/ ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3822383ac3..212502e20f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,105 +43,17 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") # Third-party include_directories(${PYBIND11_INCLUDE_DIR}) -set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") -if(APPLE) - set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6) -endif() # ######### # LLVM # ######### if(NOT MLIR_DIR) - if(NOT LLVM_LIBRARY_DIR) - if(WIN32) - find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu) - - include_directories(${LLVM_INCLUDE_DIRS}) - separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) - add_definitions(${LLVM_DEFINITIONS_LIST}) - - llvm_map_components_to_libnames(LLVM_LIBRARIES support core - NVPTXInfo nvptxcodegen - AMDGPUInfo AMDGPUcodegen - ) - else() - find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu") - endif() - - message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") - - # FindLLVM outputs LLVM_LIBRARY_DIRS but we expect LLVM_LIBRARY_DIR here - set(LLVM_LIBRARY_DIR ${LLVM_LIBRARY_DIRS}) - - if(APPLE) - set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14") - endif() - - # sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros - else() - set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}") - set(LLVM_LIBRARIES - LLVMNVPTXCodeGen - LLVMNVPTXDesc - LLVMNVPTXInfo - LLVMAMDGPUDisassembler - LLVMMCDisassembler - LLVMAMDGPUCodeGen - LLVMMIRParser - LLVMGlobalISel - LLVMSelectionDAG - LLVMipo - LLVMInstrumentation - LLVMVectorize - LLVMLinker - LLVMIRReader - LLVMAsmParser - LLVMFrontendOpenMP - LLVMAsmPrinter - LLVMDebugInfoDWARF - LLVMCodeGen - LLVMTarget - LLVMScalarOpts - LLVMInstCombine - LLVMAggressiveInstCombine - LLVMTransformUtils - LLVMBitWriter - LLVMAnalysis - LLVMProfileData - LLVMObject - LLVMTextAPI - LLVMBitReader - LLVMAMDGPUAsmParser - LLVMMCParser - LLVMAMDGPUDesc - LLVMAMDGPUUtils - LLVMMC - LLVMDebugInfoCodeView - LLVMDebugInfoMSF - LLVMCore - LLVMRemarks - LLVMBitstreamReader - LLVMBinaryFormat - LLVMAMDGPUInfo - LLVMSupport - LLVMDemangle - LLVMPasses - LLVMAnalysis - LLVMTransformUtils - LLVMScalarOpts - LLVMTransformUtils - LLVMipo - LLVMObjCARCOpts - LLVMCoroutines - LLVMAnalysis - ) - endif() - set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir) endif() @@ -160,39 +72,6 @@ set(GenISAIntrinsics_LIBRARY ) message(STATUS "GenISAIntrinsics_LDFLAGS: ${GenISAIntrinsics_LDFLAGS}") -# Python module -if(TRITON_BUILD_PYTHON_MODULE) - message(STATUS "Adding Python module") - set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) - set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc - ${PYTHON_SRC_PATH}/ir.cc - ${PYTHON_SRC_PATH}/passes.cc - ${PYTHON_SRC_PATH}/interpreter.cc - ${PYTHON_SRC_PATH}/llvm.cc - ${PYTHON_SRC_PATH}/nvidia.cc) - include_directories("." ${PYTHON_SRC_PATH}) - - if(PYTHON_INCLUDE_DIRS) - include_directories(${PYTHON_INCLUDE_DIRS}) - else() - find_package(Python3 REQUIRED COMPONENTS Development Interpreter) - include_directories(${Python3_INCLUDE_DIRS}) - link_directories(${Python3_LIBRARY_DIRS}) - link_libraries(${Python3_LIBRARIES}) - add_link_options(${Python3_LINK_OPTIONS}) - endif() -endif() - -# # Triton -# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) -# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE) -# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -# set_target_properties(triton PROPERTIES SUFFIX ".pyd") -# set_target_properties(triton PROPERTIES PREFIX "lib") -# else() -# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -# endif() - # MLIR find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR}) @@ -206,6 +85,7 @@ include(AddMLIR) # Disable warnings that show up in external code (gtest;pybind11) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) @@ -229,19 +109,53 @@ if (NOT WIN32 AND NOT APPLE) link_libraries(stdc++fs) endif() + +# ----- + +# ------ if(TRITON_BUILD_PYTHON_MODULE) - add_library(triton SHARED ${PYTHON_SRC}) + message(STATUS "Adding Python module") + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) + include_directories(${PYTHON_SRC_PATH}) + + if(PYTHON_INCLUDE_DIRS) + include_directories(${PYTHON_INCLUDE_DIRS}) + else() + find_package(Python3 REQUIRED COMPONENTS Development Interpreter) + include_directories(${Python3_INCLUDE_DIRS}) + link_directories(${Python3_LIBRARY_DIRS}) + link_libraries(${Python3_LIBRARIES}) + add_link_options(${Python3_LINK_OPTIONS}) + endif() + + set(TRITON_CODEGEN_BACKENDS "xpu") + foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) + add_subdirectory(third_party/${CODEGEN_BACKEND}) + endforeach() + set(TRITON_LIBRARIES + TritonIR TritonAnalysis TritonTransforms + TritonToTritonGPU + TritonGPUIR TritonGPUTransforms - TritonNvidiaGPUTransforms TritonLLVMIR + TritonNvidiaGPUIR + MLIRAMDGPUDialect + TritonAnalysis + NVGPUToLLVM + TritonNvidiaGPUTransforms + TritonGPUToLLVM TritonSPIRV - ${dialect_libs} - ${conversion_libs} + MLIRNVVMDialect + MLIRNVVMToLLVMIRTranslation + MLIRGPUToNVVMTransforms + MLIRGPUToGPURuntimeTransforms + MLIRGPUTransforms # optimizations + MLIRControlFlowToLLVM MLIRBytecodeWriter MLIRPass MLIRTransforms @@ -249,33 +163,39 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRSupport MLIRTargetLLVMIRExport MLIRMathToLLVM - MLIRNVVMToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation MLIRGENXToLLVMIRTranslation + MLIRGPUDialect MLIRIR + + # LLVM + LLVMPasses + LLVMX86CodeGen + LLVMX86AsmParser + LLVMNVPTXCodeGen + # LLVMNVPTXAsmPrinter + LLVMAMDGPUCodeGen + LLVMAMDGPUAsmParser + ) + # Define triton library + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + ${PYTHON_SRC_PATH}/ir.cc + ${PYTHON_SRC_PATH}/passes.cc + ${PYTHON_SRC_PATH}/interpreter.cc + ${PYTHON_SRC_PATH}/llvm.cc + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/xpu/triton_xpu.cc) + + # Link triton with its dependencies + target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) if(WIN32) - target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${LLVM_SPIRV_LIBRARY} - ${GenISAIntrinsics_LIBRARY} - ${CMAKE_DL_LIBS} ${TRITON_LIBRARIES} - ) - elseif(APPLE) - target_link_libraries(triton ${LLVM_LIBRARIES} ${LLVM_SPIRV_LIBRARY} - ${GenISAIntrinsics_LIBRARY} z - ${TRITON_LIBRARIES} - ) + target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) else() - target_link_libraries(triton ${LLVM_LIBRARIES} z - ${TRITON_LIBRARIES} ${LLVM_SPIRV_LIBRARY} ${GenISAIntrinsics_LIBRARY} - ) + target_link_libraries(triton PRIVATE z) endif() - - target_link_options(triton PRIVATE ${LLVM_LDFLAGS} ${LLVM_SPIRV_LDFLAGS} ${GenISAIntrinsics_LDFLAGS}) -endif() - -if(UNIX AND NOT APPLE) - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") + target_link_options(triton PRIVATE ${LLVM_LDFLAGS} ${GenISAIntrinsics_LDFLAGS}) + set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "") endif() if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) @@ -286,15 +206,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") endif() - target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS}) -endif() - -list(LENGTH TRITON_CODEGEN_BACKENDS CODEGEN_BACKENDS_LEN) -if (${CODEGEN_BACKENDS_LEN} GREATER 0) - set(PYTHON_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party) - foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) - add_subdirectory(third_party/${CODEGEN_BACKEND}) - endforeach() + target_link_libraries(triton ${PYTHON_LDFLAGS}) endif() add_subdirectory(bin) diff --git a/include/triton/Target/PTX/TmaMetadata.h b/include/triton/Target/PTX/TmaMetadata.h index a36f8f8cd7..eb11a74693 100644 --- a/include/triton/Target/PTX/TmaMetadata.h +++ b/include/triton/Target/PTX/TmaMetadata.h @@ -24,7 +24,7 @@ #ifndef TRITON_TARGET_PTX_TMAMETADATA_H #define TRITON_TARGET_PTX_TMAMETADATA_H -#include "python/triton/third_party/cuda/include/cuda.h" +#include "third_party/cuda/backend/include/cuda.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index dcd8c50121..ff9bf462aa 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -597,19 +597,6 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { !srcTy.getElementType().isF32(); } -bool isSingleValue(Value value) { - // Don't consider load as expensive if it is loading a scalar. - if (auto tensorTy = value.getType().dyn_cast()) - return tensorTy.getNumElements() == 1; - // TODO: Handle other cases. - // For example, when ptr is a tensor of single value. - // It means that ptr is a resultant of broadcast or generated through - // a chain of broadcast and other operations. - // Rematerialize it without considering contiguous memory access pattern is - // fine. - return true; -} - namespace { /// A data structure similar to SetVector but maintains diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index b3b8a1f8ec..2d74417ca4 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -343,6 +343,19 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { return encoding; } +bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = value.getType().dyn_cast()) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + bool isExpensiveLoadOrStore(Operation *op) { // Case 1: Pointer of tensor is always expensive auto operandType = op->getOperand(0).getType(); diff --git a/python/MANIFEST.in b/python/MANIFEST.in index c8c8189198..e69de29bb2 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,5 +0,0 @@ -graft src -graft triton/third_party -graft triton/tools -graft triton/runtime/backends/ -graft triton/language/extra diff --git a/python/setup.py b/python/setup.py index 1313e3f007..a23cf2b11b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -146,9 +146,7 @@ def download_and_copy(src_path, variable, version, url_func): if arch == "x86_64": arch = "64" url = url_func(arch, version) - dst_prefix = os.path.join(base_dir, "triton") - dst_suffix = os.path.join("third_party", "cuda", src_path) - dst_path = os.path.join(dst_prefix, dst_suffix) + dst_path = os.path.join(base_dir, os.pardir, "third_party", "cuda", "backend", src_path) is_linux = platform.system() == "Linux" download = False if is_linux: @@ -164,6 +162,7 @@ def download_and_copy(src_path, variable, version, url_func): file.extractall(path=temp_dir) src_path = os.path.join(temp_dir, src_path) os.makedirs(os.path.split(dst_path)[0], exist_ok=True) + print(f'copy {src_path} to {dst_path} ...') shutil.copy(src_path, dst_path) @@ -273,16 +272,17 @@ def build_extension(self, ext): cfg = get_build_type() build_args = ["--config", cfg] - codegen_backends = get_codegen_backends() - if len(codegen_backends) > 0: - all_codegen_backends = ';'.join(codegen_backends) - cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends] + # third-party backend + + # codegen_backends = get_codegen_backends() + # if len(codegen_backends) > 0: + # all_codegen_backends = ';'.join(codegen_backends) + # cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends] if platform.system() == "Windows": cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] if sys.maxsize > 2**32: cmake_args += ["-A", "x64"] - build_args += ["--", "/m"] else: cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) @@ -346,6 +346,14 @@ def build_extension(self, ext): f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) +plugins = ["xpu"] +for plugin in plugins: + src_path = os.path.join(os.pardir, "third_party", plugin, "backend") + dst_path = os.path.join(os.path.dirname(__file__), "triton", "backends", plugin) + if os.path.exists(dst_path): + shutil.rmtree(dst_path) + shutil.copytree(src_path, dst_path) + setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), version="2.1.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), @@ -356,19 +364,21 @@ def build_extension(self, ext): packages=[ "triton", "triton/_C", - "triton/common", "triton/compiler", - "triton/compiler/backends", "triton/language", "triton/language/extra", "triton/ops", "triton/ops/blocksparse", "triton/runtime", - "triton/runtime/backends", - "triton/third_party", + "triton/backends", + "triton/backends/xpu", "triton/tools", ], install_requires=["filelock"], + package_data={ + "triton/tools": ["compile.h", "compile.c"], + "triton/backends/xpu": ["bin/*", "lib/*", "include/*"], + }, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, diff --git a/python/src/ir.cc b/python/src/ir.cc index 6aa402fe54..e7a9ed4369 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -189,11 +189,11 @@ void init_triton_ir(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert< + mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect, + mlir::math::MathDialect, mlir::arith::ArithDialect, + mlir::index::IndexDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect, + mlir::cf::ControlFlowDialect, mlir::LLVM::LLVMDialect>(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); context.appendDialectRegistry(registry); @@ -1527,9 +1527,18 @@ void init_triton_ir(py::module &&m) { .def(py::init()) .def("enable_debug", [](mlir::PassManager &self) { + auto *context = self.getContext(); + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->disableMultithreading(); + context->getDiagEngine().registerHandler( + [](mlir::Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return mlir::success(); + }); + if (!::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP")) return; - self.getContext()->disableMultithreading(); auto printingFlags = mlir::OpPrintingFlags(); printingFlags.elideLargeElementsAttrs(16); printingFlags.enableDebugInfo(); diff --git a/python/src/llvm.cc b/python/src/llvm.cc index d4893b9ef1..06abc79e79 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -15,6 +15,7 @@ #include "llvm/Passes/OptimizationLevel.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" @@ -104,6 +105,14 @@ void init_triton_llvm(py::module &&m) { py::class_(m, "context", py::module_local()) .def(py::init<>()); + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + py::class_(m, "module", py::module_local()) .def( "__str__", @@ -113,7 +122,23 @@ void init_triton_llvm(py::module &&m) { os << *self; return os.str(); }, - ret::take_ownership); + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + return mod->getFunctionList(); + }, + ret::reference_internal); + + py::class_(m, "function", py::module_local()) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + .def("has_public_visibility", + [](llvm::Function *fn) { + return fn->getVisibility() == llvm::GlobalValue::DefaultVisibility; + }) + .def("is_declaration", &llvm::Function::isDeclaration); // optimization levels py::class_(m, "optimization_level", @@ -125,9 +150,12 @@ void init_triton_llvm(py::module &&m) { m.attr("OPTIMIZE_Os") = (llvm::OptimizationLevel::Os); m.attr("OPTIMIZE_Oz") = (llvm::OptimizationLevel::Oz); - m.def("to_module", [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { - return mlir::translateModuleToLLVMIR(mod, ctx); - }); + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); m.def("optimize_module", [](llvm::Module *mod, const llvm::OptimizationLevel &opt) { @@ -227,6 +255,17 @@ void init_triton_llvm(py::module &&m) { mod->setDataLayout(layout); }); + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + m.def("link_extern_lib", [](llvm::Module *mod, std::string path) { llvm::SMDiagnostic err; auto &ctx = mod->getContext(); diff --git a/python/src/main.cc b/python/src/main.cc index 58e2566250..36f5b2679f 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -6,7 +6,7 @@ void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); -void init_triton_nvidia(pybind11::module &&m); +void init_triton_xpu(pybind11::module &&m); PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; @@ -15,5 +15,5 @@ PYBIND11_MODULE(libtriton, m) { init_triton_passes(m.def_submodule("passes")); init_triton_interpreter(m.def_submodule("interpreter")); init_triton_llvm(m.def_submodule("llvm")); - init_triton_nvidia(m.def_submodule("nvidia")); + init_triton_xpu(m.def_submodule("xpu")); } diff --git a/python/src/passes.cc b/python/src/passes.cc index 5125bc2238..2f582a2bf1 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -3,6 +3,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -11,6 +13,14 @@ namespace py = pybind11; +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + void init_triton_passes_common(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); @@ -51,6 +61,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { void init_triton_passes_convert(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); } @@ -61,6 +72,7 @@ void init_triton_passes_llvmir(py::module &&m) { } void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); init_triton_passes_common(m.def_submodule("common")); init_triton_passes_convert(m.def_submodule("convert")); init_triton_passes_ttir(m.def_submodule("ttir")); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2168f9dfcd..0a1560d0a3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -10,9 +10,18 @@ import triton import triton.language as tl -from triton.common.build import is_hip, is_spirv from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret + +def is_hip(): + return triton.runtime.driver.get_current_target()[0] == "hip" + + +def is_spirv(): + import torch + return torch.xpu.is_available() + + int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] float_dtypes = ['float16', 'float32', 'float64'] @@ -3380,9 +3389,8 @@ def _kernel(dst): @pytest.mark.parametrize("dtype_str, expr, lib_path", [('int32', 'math.ffs', ''), ('float32', 'math.log2', ''), - ('float32', 'math.scalbn', ''), - ('float32', 'math.pow', tl.math.libdevice_path()), - ('float64', 'math.pow_dtype', tl.math.libdevice_path()), + ('float32', 'math.scalbn', ''), ('float32', 'math.pow', ''), + ('float64', 'math.pow_dtype', ''), ('float64', 'math.norm4d', '')]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): @@ -3441,7 +3449,7 @@ def kernel(X, Y, BLOCK: tl.constexpr): @pytest.mark.parametrize("dtype_str, expr, lib_path", [('float32', 'math.pow', ''), ('float64', 'math.pow_dtype', ''), - ('float64', 'math.pow', tl.math.libdevice_path())]) + ('float64', 'math.pow', '')]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 42c1f6a08a..bf2d84bb31 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -6,7 +6,7 @@ import triton import triton.language as tl -from triton.common.backend import path_to_spirvdis +from triton.backends.xpu.compiler import _path_to_binary @triton.jit @@ -76,7 +76,7 @@ def kernel_dot_combine(x): def extract_file_lines(spv): - dis = path_to_spirvdis() + dis, _ = _path_to_binary("spirv-dis") fd, path = tempfile.mkstemp() with open(fd, 'wb') as spvbin: spvbin.write(spv) @@ -132,7 +132,7 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True): @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): try: - _ = path_to_spirvdis() + _, _ = _path_to_binary("spirv-dis") except BaseException: pytest.skip("spirv-dis is not available") diff --git a/python/test/unit/runtime/test_driver.py b/python/test/unit/runtime/test_driver.py index b63927d89b..103b2ef520 100644 --- a/python/test/unit/runtime/test_driver.py +++ b/python/test/unit/runtime/test_driver.py @@ -11,4 +11,4 @@ def test_is_lazy(): assert isinstance(triton.runtime.driver, getattr(mod, "LazyProxy")) assert triton.runtime.driver._obj is None utils = triton.runtime.driver.utils # noqa: F841 - assert issubclass(triton.runtime.driver._obj.__class__, getattr(mod, "DriverBase")) + assert issubclass(triton.runtime.driver._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 613836d24e..5ec600e92e 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -7,7 +7,7 @@ import numpy as np import triton -from triton.common import cuda_include_dir, libcuda_dirs +from triton.backends.cuda.driver import include_dir, library_dir kernel_utils_src = """ import triton @@ -99,13 +99,13 @@ def kernel(C, A, B, M, N, K, def gen_kernel_library(dir, libname): c_files = glob.glob(os.path.join(dir, "*.c")) subprocess.run( - ["gcc"] + c_files + ["-I", cuda_include_dir(), "-c", "-fPIC"], + ["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"], check=True, cwd=dir, ) o_files = glob.glob(os.path.join(dir, "*.o")) subprocess.run( - ["gcc"] + o_files + ["-shared", "-o", libname, "-L", libcuda_dirs()[0]], + ["gcc"] + o_files + ["-shared", "-o", libname, "-L", library_dir[0]], check=True, cwd=dir, ) @@ -176,9 +176,9 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): ["gcc"] + [ "test.c", "-I", - cuda_include_dir(), + include_dir[0], "-L", - libcuda_dirs()[0], + library_dir[0], "-l", "cuda", "-L", diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py new file mode 100644 index 0000000000..cabcec71b8 --- /dev/null +++ b/python/triton/backends/__init__.py @@ -0,0 +1,50 @@ +import os +import importlib +import inspect +from dataclasses import dataclass +from .driver import DriverBase +from .compiler import BaseBackend + + +def _load_module(name, path): + spec = importlib.util.spec_from_file_location(name[:-3], path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _find_concrete_subclasses(module, base_class): + ret = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: BaseBackend = None + driver: DriverBase = None + + +def _discover_backends(): + backends = dict() + root = os.path.dirname(__file__) + for name in os.listdir(root): + if not os.path.isdir(os.path.join(root, name)): + continue + if name.startswith('__'): + continue + compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) + driver = _load_module(name, os.path.join(root, name, 'driver.py')) + backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + _find_concrete_subclasses(driver, DriverBase)) + return backends + + +backends = _discover_backends() diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py new file mode 100644 index 0000000000..0655b3fa5a --- /dev/null +++ b/python/triton/backends/compiler.py @@ -0,0 +1,64 @@ +from abc import ABCMeta, abstractmethod, abstractclassmethod +import os +import subprocess +import re + + +class BaseBackend(metaclass=ABCMeta): + + def __init__(self, target: tuple) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + def _path_to_binary(binary: str): + base_dir = os.path.join(os.path.dirname(__file__), os.pardir) + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(base_dir, "third_party", "cuda", "bin", binary), + ] + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + @abstractclassmethod + def supports_target(target: tuple): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. + """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py new file mode 100644 index 0000000000..e66442943b --- /dev/null +++ b/python/triton/backends/driver.py @@ -0,0 +1,34 @@ +from abc import ABCMeta, abstractmethod, abstractclassmethod + + +class DriverBase(metaclass=ABCMeta): + + @abstractclassmethod + def is_active(self): + pass + + @abstractmethod + def get_current_target(self): + pass + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + self.get_device_capability = torch.cuda.get_device_capability + try: + from torch._C import _cuda_getCurrentRawStream + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args diff --git a/python/triton/common/__init__.py b/python/triton/common/__init__.py deleted file mode 100644 index dfb6f8870e..0000000000 --- a/python/triton/common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .build import _build, cuda_include_dir, libcuda_dirs - -__all__ = ["_build", "libcuda_dirs", "cuda_include_dir"] diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py deleted file mode 100644 index f416036459..0000000000 --- a/python/triton/common/backend.py +++ /dev/null @@ -1,205 +0,0 @@ -import functools -import hashlib -import importlib -import importlib.util -import os -import re -import subprocess -import traceback -from typing import Dict - -from ..runtime.driver import DriverBase - -TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -TRITON_VERSION = "2.1.0" - - -class BaseBackend: - - def __init__(self, device_type: str) -> None: - self.device_type = device_type - - def add_stages(self, arch, extern_libs, stages): - """ - Custom the arch, extern_libs and stages per backend specific requirement - """ - raise NotImplementedError - - def add_meta_info(self, ir, cur_module, next_module, metadata, asm): - """ - Custom the ir, module, metadata and asm per backend specific requirement - """ - raise NotImplementedError - - def get_load_binary_fn(self): - """ - Return a callable to load binary - """ - raise NotImplementedError - - def get_driver(self) -> DriverBase: - """ - Get the backend driver. Please refer to "DriverBase" for more details - """ - raise NotImplementedError - - def get_stream(self): - """ - Get stream for current device - """ - raise NotImplementedError - - def get_device_properties(self, device): - raise NotImplementedError - - def get_current_device(self): - """ - Get current device - """ - raise NotImplementedError - - def set_current_device(self, device): - """ - Set current device as the given device - """ - raise NotImplementedError - - def get_kernel_bin(self): - raise NotImplementedError - - def make_launcher_stub(self, name, signature, constants): - """ - Generate the launcher stub to launch the kernel - """ - raise NotImplementedError - - def get_architecture_descriptor(self, **kwargs): - """ - Get the architecture descriptor the backend - """ - raise NotImplementedError - - @classmethod - def create_backend(cls, device_type: str): - return cls(device_type) - - -_backends: Dict[str, BaseBackend] = {} - - -def register_backend(device_type: str, backend_cls: type): - if device_type not in _backends: - _backends[device_type] = backend_cls.create_backend(device_type) - - -def get_backend(device_type: str): - if device_type not in _backends: - device_backend_package_name = f"...third_party.{device_type}" - if importlib.util.find_spec(device_backend_package_name, package=__spec__.name): - try: - importlib.import_module(device_backend_package_name, package=__spec__.name) - except Exception: - traceback.print_exc() - else: - return None - return _backends[device_type] if device_type in _backends else None - - -def _path_to_binary(binary: str): - base_dir = os.path.join(os.path.dirname(__file__), os.pardir) - paths = [ - os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), - os.path.join(base_dir, "third_party", "cuda", "bin", binary), - ] - - for p in paths: - bin = p.split(" ")[0] - if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) - if result is not None: - version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) - if version is not None: - return p, version.group(1) - raise RuntimeError(f"Cannot find {binary}") - - -def _path_to_spirv_binary(binary: str): - base_dir = os.path.join(os.path.dirname(__file__), os.pardir) - paths = [ - os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), - os.path.join(base_dir, "third_party", "spirv", "bin", binary) - ] - - for p in paths: - bin = p.split(" ")[0] - if os.path.exists(bin) and os.path.isfile(bin): - result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) - if result is not None: - return p - raise RuntimeError(f"Cannot find {binary}") - - -@functools.lru_cache() -def path_to_ptxas(): - return _path_to_binary("ptxas") - - -@functools.lru_cache() -def path_to_cuobjdump(): - return _path_to_binary("cuobjdump") - - -@functools.lru_cache() -def path_to_nvdisasm(): - return _path_to_binary("nvdisasm") - - -@functools.lru_cache() -def path_to_spirvdis(): - return _path_to_spirv_binary("spirv-dis") - - -@functools.lru_cache() -def compute_core_version_key(): - import pkgutil - contents = [] - # frontend - with open(__file__, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - # compiler - compiler_path = os.path.join(TRITON_PATH, 'compiler') - backends_path = os.path.join(TRITON_PATH, 'compiler', 'backends') - for lib in pkgutil.iter_modules([compiler_path, backends_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - # backend - libtriton_hash = hashlib.sha1() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: - while True: - chunk = f.read(1024**2) - if not chunk: - break - libtriton_hash.update(chunk) - contents.append(libtriton_hash.hexdigest()) - # language - language_path = os.path.join(TRITON_PATH, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.sha1(f.read()).hexdigest()] - return '-'.join(TRITON_VERSION) + '-'.join(contents) - - -_cached_cuda_version_key = None - - -def get_cuda_version_key(): - global _cached_cuda_version_key - if _cached_cuda_version_key is None: - key = compute_core_version_key() - try: - ptxas = path_to_ptxas()[0] - ptxas_version = subprocess.check_output([ptxas, "--version"]) - except RuntimeError: - ptxas_version = b"NO_PTXAS" - _cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest() - return _cached_cuda_version_key diff --git a/python/triton/common/build.py b/python/triton/common/build.py deleted file mode 100644 index a2a866eb95..0000000000 --- a/python/triton/common/build.py +++ /dev/null @@ -1,173 +0,0 @@ -import contextlib -import functools -import io -import os -import shutil -import subprocess -import sys -import sysconfig - -import setuptools - - -# TODO: is_hip shouldn't be here -def is_hip(): - import torch - return torch.version.hip is not None - - -# TODO: properly set is_spirv -def is_spirv(): - import torch - return torch.xpu.is_available() - - -@functools.lru_cache() -def libcuda_dirs(): - env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") - if env_libcuda_path: - return [env_libcuda_path] - - libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() - # each line looks like the following: - # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 - locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so" in line] - dirs = [os.path.dirname(loc) for loc in locs] - env_ld_library_path = os.getenv("LD_LIBRARY_PATH") - if env_ld_library_path and not dirs: - dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so"))] - msg = 'libcuda.so cannot found!\n' - if locs: - msg += 'Possible files are located at %s.' % str(locs) - msg += 'Please create a symlink of libcuda.so to any of the file.' - else: - msg += 'Please make sure GPU is setup and then run "/sbin/ldconfig"' - msg += ' (requires sudo) to refresh the linker cache.' - assert any(os.path.exists(os.path.join(path, 'libcuda.so')) for path in dirs), msg - return dirs - - -@functools.lru_cache() -def rocm_path_dir(): - return os.getenv("ROCM_PATH", default="/opt/rocm") - - -@functools.lru_cache() -def ze_path_dir(): - return os.getenv("ZE_PATH", default="/usr/local") - - -@contextlib.contextmanager -def quiet(): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = io.StringIO(), io.StringIO() - try: - yield - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr - - -@functools.lru_cache() -def cuda_include_dir(): - base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir) - cuda_path = os.path.join(base_dir, "third_party", "cuda") - return os.path.join(cuda_path, "include") - - -def _build(name, src, srcdir): - if is_spirv(): - ze_lib_dir = os.path.join(ze_path_dir(), "lib") - ze_include_dir = os.path.join(ze_path_dir(), "include/level_zero") - elif is_hip(): - hip_lib_dir = os.path.join(rocm_path_dir(), "lib") - hip_include_dir = os.path.join(rocm_path_dir(), "include") - else: - cuda_lib_dirs = libcuda_dirs() - cu_include_dir = cuda_include_dir() - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - # try to avoid setuptools if possible - cc = os.environ.get("CC") - if cc is None: - # TODO: support more things here. - clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - - if is_spirv(): - icpx = None - cxx = os.environ.get("CXX") - if cxx is None: - clangpp = shutil.which("clang++") - icpx = shutil.which("icpx") - cxx = icpx if icpx is not None else clangpp - import numpy as np - numpy_include_dir = np.get_include() - if icpx is not None: - ret = subprocess.check_call([ - cxx, src, "-fsycl", "-std=c++17", "-g", f"-I{ze_include_dir}", f"-I{py_include_dir}", - f"-I{numpy_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{ze_lib_dir}", "-lze_loader", "-o", so - ]) - else: - ret = subprocess.check_call([ - cxx, src, "-std=c++17", "-g", f"-I{ze_include_dir}", f"-I{py_include_dir}", f"-I{numpy_include_dir}", - f"-I{srcdir}", "-shared", "-fPIC", f"-L{ze_lib_dir}", "-lze_loader", "-o", so - ]) - elif is_hip(): - ret = subprocess.check_call([ - cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", - f"-L{hip_lib_dir}", "-lamdhip64", "-o", so - ]) - else: - cc_cmd = [ - cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", - "-o", so - ] - cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] - ret = subprocess.check_call(cc_cmd) - - if ret == 0: - return so - # fallback on setuptools - extra_compile_args = [] - library_dirs = cuda_lib_dirs - include_dirs = [srcdir, cu_include_dir] - libraries = ['cuda'] - # extra arguments - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name=name, - language='c', - sources=[src], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], - extra_link_args=extra_link_args, - library_dirs=library_dirs, - libraries=libraries, - ) - # build extension module - args = ['build_ext'] - args.append('--build-temp=' + srcdir) - args.append('--build-lib=' + srcdir) - args.append('-q') - args = dict( - name=name, - ext_modules=[ext], - script_args=args, - ) - with quiet(): - setuptools.setup(**args) - return so diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index 558b5158fe..33b55c27a0 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,4 @@ -from .compiler import (CompiledKernel, ASTSource, compile, AttrsDescriptor) +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend from .errors import CompilationError -__all__ = ["compile", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError"] +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError"] diff --git a/python/triton/compiler/backends/__init__.py b/python/triton/compiler/backends/__init__.py deleted file mode 100644 index 3d71f15560..0000000000 --- a/python/triton/compiler/backends/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .cuda import CUDABackend -from .xpu import XPUBackend - - -def make_backend(target): - return {"cuda": CUDABackend, "xpu": XPUBackend}[target[0]](target) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py deleted file mode 100644 index c293a6ab62..0000000000 --- a/python/triton/compiler/backends/cuda.py +++ /dev/null @@ -1,291 +0,0 @@ -from triton.common.backend import BaseBackend -from dataclasses import dataclass -from ...common.backend import get_cuda_version_key, path_to_ptxas -from ..._C.libtriton import ir, passes, nvidia, llvm -import functools -from typing import Any -from ..make_launcher import make_stub -from ..utils import get_ids_of_tensormaps, parse_tma_info -import hashlib -import re -import tempfile -import signal -import os -import subprocess -from pathlib import Path - - -@functools.lru_cache() -def ptx_get_version(cuda_version) -> int: - ''' - Get the highest PTX version supported by the current CUDA driver. - ''' - assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split('.')) - if major == 12: - return 80 + minor - if major == 11: - return 70 + minor - if major == 10: - return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") - - -@dataclass(frozen=True) -class CUDAOptions: - num_warps: int = 4 - num_ctas: int = 1 - num_stages: int = 3 - cluster_dims: tuple = (1, 1, 1) - ptx_version: int = None - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: bool = None - extern_libs: dict = None - debug: bool = False - - def __post_init__(self): - default_libdir = Path(__file__).parent.parent.parent / 'third_party' / 'cuda' / 'lib' - extern_libs = dict() if self.extern_libs is None else dict(self.extern_libs) - if not extern_libs.get('libdevice', None): - extern_libs['libdevice'] = str(default_libdir / 'libdevice.10.bc') - object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) - assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ - "num_warps must be a power of 2" - - def hash(self): - key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) - return hashlib.md5(key.encode("utf-8")).hexdigest() - - -class CUDABackend(BaseBackend): - - def __init__(self, device_type: tuple) -> None: - super().__init__(device_type) - self.capability = device_type[1] - assert isinstance(self.capability, int) - - def parse_options(self, opts) -> Any: - args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} - args["allow_fp8e4nv"] = self.capability >= 89 - args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 - options = CUDAOptions(**args) - assert options.num_ctas == 1 or self.capability >= 90, \ - f"num_ctas > 1 supported only on SM90+. Got num_ctas={options.num_ctas}, SM{self.capability}" - return options - - @staticmethod - def load_dialects(ctx): - nvidia.load_dialects(ctx) - - @staticmethod - def make_ttir(mod, metadata, opt): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.common.add_inliner(pm) - passes.ttir.add_combine(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_reorder_broadcast(pm) - passes.common.add_cse(pm) - passes.common.add_licm(pm) - passes.common.add_symbol_dce(pm) - pm.run(mod) - return mod - - @staticmethod - def make_ttgir(mod, metadata, opt, capability): - cluster_info = nvidia.ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] - # TTIR -> TTGIR - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability) - # optimize TTGIR - passes.ttgpuir.add_coalesce(pm) - # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) - nvidia.passes.ttgpuir.add_rewrite_tensor_pointer(pm, capability) - nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_optimize_thread_locality(pm) - passes.ttgpuir.add_accelerate_matmul(pm, capability) - passes.ttgpuir.add_remove_layout_conversions(pm) - if opt.optimize_epilogue: - passes.ttgpuir.add_optimize_epilogue(pm) - passes.ttgpuir.add_optimize_dot_operands(pm) - passes.common.add_cse(pm) - # `num_warps` does not mean the total number of warps of a CTA when - # warp specialization is enabled. - # it's the responsibility of the compiler to figure out the exact - # `num_warps` to use. - # TODO: support the case where `num_warps` from user is not 4. - ws_enabled = False - if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: - nvidia.passes.ttnvgpuir.add_wsfeasibility_checking(pm, capability) - pm.run(mod) - ws_enabled = nvidia.passes.ttnvgpuir.is_ws_supported(mod) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - metadata["ws_enabled"] = ws_enabled - if ws_enabled: - nvidia.passes.ttnvgpuir.add_wsdecomposing(pm, capability) - nvidia.passes.ttnvgpuir.add_wspipeline(pm, opt.num_stages, opt.num_warps, capability) - nvidia.passes.ttnvgpuir.add_wsmutex(pm, capability) - nvidia.passes.ttnvgpuir.add_wsmaterialization(pm, capability) - passes.common.add_licm(pm) - passes.common.add_cse(pm) - else: - passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) - nvidia.passes.ttnvgpuir.add_materialize_load_store(pm, opt.num_warps, capability) - if capability // 10 <= 8: - passes.ttgpuir.add_prefetch(pm) - passes.ttgpuir.add_optimize_dot_operands(pm) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_decompose_conversions(pm) - nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) - passes.ttgpuir.add_reorder_instructions(pm) - passes.common.add_cse(pm) - passes.common.add_symbol_dce(pm) - if capability // 10 >= 9: - nvidia.passes.ttnvgpuir.add_fence_insertion(pm) - nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) - passes.common.add_canonicalizer(pm) - pm.run(mod) - metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) - return mod - - @staticmethod - def make_llir(src, metadata, options, capability): - # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") - if num_warp_groups is not None: - metadata["num_warps"] *= num_warp_groups - mod = src - # TritonGPU -> LLVM-IR (MLIR) - tma_infos = nvidia.TMAInfos() - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.convert.add_scf_to_cf(pm) - passes.convert.add_index_to_llvmir(pm) - nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, tma_infos) - if metadata["ws_enabled"]: - passes.common.add_licm(pm) - passes.common.add_cse(pm) - nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) - passes.convert.add_arith_to_llvmir(pm) - passes.common.add_canonicalizer(pm) - passes.common.add_cse(pm) - passes.common.add_symbol_dce(pm) - if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": - passes.llvmir.add_di_scope(pm) - pm.run(mod) - # LLVM-IR (MLIR) -> LLVM-IR (LLVM) - nvidia.init_llvm() - context = llvm.context() - llvm_mod = llvm.to_module(mod, context) - nvidia.set_nvvm_reflect_ftz(llvm_mod) - if options.extern_libs: - for name, path in options.extern_libs: - llvm.link_extern_lib(llvm_mod, path) - llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - # Get some metadata - if len(tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] - metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) - metadata["shared"] = src.get_int_attr("triton_gpu.shared") - ret = str(llvm_mod) - del llvm_mod - del context - return ret - - @staticmethod - def make_ptx(src, metadata, opt, capability): - proc = 'sm_90a' if capability == 90 else f'sm_{capability}' - ret = llvm.translate_to_asm(src, 'nvptx64-nvidia-cuda', proc, '', ['nvptx-short-ptr'], opt.enable_fp_fusion, - False) - # Find kernel names (there should only be one) - names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) - assert len(names) == 1 - metadata["name"] = names[0] - # post-process - ptx_version = opt.ptx_version - if ptx_version is None: - _, cuda_version = path_to_ptxas() - ptx_version = ptx_get_version(cuda_version) - ptx_version = f'{ptx_version//10}.{ptx_version%10}' - ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) - # Remove the debug flag that prevents ptxas from optimizing the code - ret = re.sub(r",\s*debug|debug,\s*", "", ret) - return ret - - @staticmethod - def make_cubin(src, metadata, opt, capability): - ptxas, _ = path_to_ptxas() - with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \ - tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: - fsrc.write(src) - fsrc.flush() - fbin = fsrc.name + '.o' - - line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' - fmad = '' if opt.enable_fp_fusion else ' --fmad=false' - suffix = 'a ' if capability == 90 else ' ' - cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' - - try: - subprocess.run(cmd, shell=True, check=True) - except subprocess.CalledProcessError as e: - with open(flog.name) as log_file: - log = log_file.read() - if e.returncode == 255: - raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') - elif e.returncode == 128 + signal.SIGSEGV: - raise RuntimeError( - f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') - else: - raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') - finally: - if os.path.exists(fsrc.name): - os.remove(fsrc.name) - if os.path.exists(flog.name): - os.remove(flog.name) - - with open(fbin, 'rb') as f: - cubin = f.read() - if os.path.exists(fbin): - os.remove(fbin) - return cubin - - def add_stages(self, stages, options): - stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) - stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) - stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability) - stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability) - - def hash(self): - return f'{get_cuda_version_key()}-{self.capability}' - - def make_launcher_stub(self, src, metadata): - ids = { - "ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": - metadata.get("ids_of_folded_args", - tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() - } - constants = src.constants if hasattr(src, "constants") else dict() - enable_warp_specialization = False - - # set constant - return make_stub(src.name, src.signature, constants, ids, enable_warp_specialization=enable_warp_specialization) - - @classmethod - def create_backend(cls, device_type: str): - return cls(device_type) diff --git a/python/triton/compiler/backends/xpu.py b/python/triton/compiler/backends/xpu.py deleted file mode 100644 index f9783dcadb..0000000000 --- a/python/triton/compiler/backends/xpu.py +++ /dev/null @@ -1,230 +0,0 @@ -from triton.common.backend import BaseBackend -from dataclasses import dataclass -from ...common.backend import get_cuda_version_key -from ..._C.libtriton import ir, passes, nvidia, llvm -import functools -from typing import Any -from ..make_launcher import make_stub -from ..utils import get_ids_of_tensormaps, parse_tma_info -import hashlib -import os -from pathlib import Path - - -@functools.lru_cache() -def ptx_get_version(cuda_version) -> int: - ''' - Get the highest PTX version supported by the current CUDA driver. - ''' - assert isinstance(cuda_version, str) - major, minor = map(int, cuda_version.split('.')) - if major == 12: - return 80 + minor - if major == 11: - return 70 + minor - if major == 10: - return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") - - -@dataclass(frozen=True) -class XPUOptions: - num_warps: int = 4 - num_ctas: int = 1 - num_stages: int = 2 - cluster_dims: tuple = (1, 1, 1) - ptx_version: int = None - enable_warp_specialization: bool = False - enable_persistent: bool = False - optimize_epilogue: bool = False - enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: bool = None - extern_libs: dict = None - debug: bool = False - - def __post_init__(self): - default_libdir = Path(__file__).parent.parent.parent / 'third_party' / 'sycl' / 'lib' - extern_libs = dict() if self.extern_libs is None else dict(self.extern_libs) - if not extern_libs.get('libdevice', None): - extern_libs['libdevice'] = str(default_libdir / 'libsycl-spir64-unknown-unknown.bc') - object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) - assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ - "num_warps must be a power of 2" - - def hash(self): - key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) - return hashlib.md5(key.encode("utf-8")).hexdigest() - - -class XPUBackend(BaseBackend): - - def __init__(self, device_type: tuple) -> None: - super().__init__(device_type) - self.capability = 0 - assert isinstance(self.capability, int) - - def parse_options(self, opts) -> Any: - args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts} - args["allow_fp8e4nv"] = True - args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 - return XPUOptions(**args) - - @staticmethod - def load_dialects(ctx): - nvidia.load_dialects(ctx) - - @staticmethod - def make_ttir(mod, metadata, opt): - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.common.add_inliner(pm) - passes.ttir.add_combine(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_reorder_broadcast(pm) - passes.common.add_cse(pm) - passes.common.add_licm(pm) - passes.common.add_symbol_dce(pm) - pm.run(mod) - return mod - - @staticmethod - def make_ttgir(mod, metadata, opt, capability): - cluster_info = nvidia.ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] - # TTIR -> TTGIR - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability) - # optimize TTGIR - passes.ttgpuir.add_coalesce(pm) - # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) - nvidia.passes.ttgpuir.add_rewrite_tensor_pointer(pm, capability) - nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_optimize_thread_locality(pm) - passes.ttgpuir.add_accelerate_matmul(pm, capability) - passes.ttgpuir.add_remove_layout_conversions(pm) - if opt.optimize_epilogue: - passes.ttgpuir.add_optimize_epilogue(pm) - passes.ttgpuir.add_optimize_dot_operands(pm) - passes.common.add_cse(pm) - # `num_warps` does not mean the total number of warps of a CTA when - # warp specialization is enabled. - # it's the responsibility of the compiler to figure out the exact - # `num_warps` to use. - # TODO: support the case where `num_warps` from user is not 4. - ws_enabled = False - if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: - nvidia.passes.ttnvgpuir.add_wsfeasibility_checking(pm, capability) - pm.run(mod) - ws_enabled = nvidia.passes.ttnvgpuir.is_ws_supported(mod) - pm = ir.pass_manager(mod.context) - pm.enable_debug() - metadata["ws_enabled"] = ws_enabled - if ws_enabled: - nvidia.passes.ttnvgpuir.add_wsdecomposing(pm, capability) - nvidia.passes.ttnvgpuir.add_wspipeline(pm, opt.num_stages, opt.num_warps, capability) - nvidia.passes.ttnvgpuir.add_wsmutex(pm, capability) - nvidia.passes.ttnvgpuir.add_wsmaterialization(pm, capability) - passes.common.add_licm(pm) - passes.common.add_cse(pm) - else: - passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) - nvidia.passes.ttnvgpuir.add_materialize_load_store(pm, opt.num_warps, capability) - if capability // 10 <= 8: - passes.ttgpuir.add_prefetch(pm) - passes.ttgpuir.add_optimize_dot_operands(pm) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_decompose_conversions(pm) - nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) - passes.ttgpuir.add_reorder_instructions(pm) - passes.common.add_cse(pm) - passes.common.add_symbol_dce(pm) - if capability // 10 >= 9: - nvidia.passes.ttnvgpuir.add_fence_insertion(pm) - nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) - passes.common.add_canonicalizer(pm) - pm.run(mod) - metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) - return mod - - @staticmethod - def make_llir(src, metadata, options, capability): - # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") - if num_warp_groups is not None: - metadata["num_warps"] *= num_warp_groups - mod = src - # TritonGPU -> LLVM-IR (MLIR) - tma_infos = nvidia.TMAInfos() - pm = ir.pass_manager(mod.context) - pm.enable_debug() - passes.convert.add_scf_to_cf(pm) - passes.convert.add_index_to_llvmir(pm) - nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, tma_infos) - if metadata["ws_enabled"]: - passes.common.add_licm(pm) - passes.common.add_cse(pm) - nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) - passes.convert.add_arith_to_llvmir(pm) - passes.common.add_canonicalizer(pm) - passes.common.add_cse(pm) - passes.common.add_symbol_dce(pm) - if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": - passes.llvmir.add_di_scope(pm) - pm.run(mod) - # LLVM-IR (MLIR) -> LLVM-IR (LLVM) - context = llvm.context() - llvm_mod = llvm.to_module(mod, context) - llvm.set_spv_target_triple(llvm_mod) - if options.extern_libs: - for name, path in options.extern_libs: - llvm.link_extern_lib(llvm_mod, path) - llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - # Get some metadata - if len(tma_infos) > 0: - metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) - for i, _ in enumerate(metadata["tensormaps_info"]): - metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] - metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) - metadata["shared"] = src.get_int_attr("triton_gpu.shared") - ret = str(llvm_mod) - del llvm_mod - del context - return ret - - @staticmethod - def make_spv(src, metadata): - ret, name = llvm.translate_to_spirv(src) - metadata["name"] = name - return ret - - def add_stages(self, stages, options): - stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) - stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) - stages["spv"] = lambda src, metadata: self.make_spv(src, metadata) - - def hash(self): - return f'{get_cuda_version_key()}-{self.capability}' - - def make_launcher_stub(self, src, metadata): - ids = { - "ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": - metadata.get("ids_of_folded_args", - tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() - } - constants = src.constants if hasattr(src, "constants") else dict() - enable_warp_specialization = False - - # set constant - return make_stub(src.name, src.signature, constants, ids, enable_warp_specialization=enable_warp_specialization) - - @classmethod - def create_backend(cls, device_type: str): - return cls(device_type) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 4a012922b8..a89c5278bf 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -1,21 +1,21 @@ from __future__ import annotations - import hashlib import json - from .._C.libtriton import get_env_vars, ir -# from ..runtime import driver, jit, JITFunction -# TODO: runtime.errors +from ..backends import backends +from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager from ..runtime.driver import driver from ..runtime.jit import (get_dev_ctxt_queue_objs, get_event_pool, get_imm_cmd_list) -from .utils import InfoFromBackendForTensorMap -from .backends import make_backend +# TODO: this shouldn't be here +from ..backends.xpu.compiler import InfoFromBackendForTensorMap from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path import re +import functools +import os @dataclass @@ -155,6 +155,37 @@ def parse_options(self): return dict() +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + # compiler + compiler_path = os.path.join(TRITON_PATH, 'compiler') + backends_path = os.path.join(TRITON_PATH, 'compiler', 'backends') + for lib in pkgutil.iter_modules([compiler_path, backends_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + # backend + libtriton_hash = hashlib.sha1() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha1(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + def compile(src, target=None, options=None): if target is None: target = driver.get_current_target() @@ -166,7 +197,7 @@ def compile(src, target=None, options=None): extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) # create cache manager - key = f"{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(get_env_vars().items()))}" + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(get_env_vars().items()))}" hash = hashlib.md5(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) metadata_filename = f"{src.name}.json" @@ -175,8 +206,7 @@ def compile(src, target=None, options=None): if metadata_path is not None: # cache hit! metadata = json.loads(Path(metadata_path).read_text()) - so_path = backend.make_launcher_stub(src, metadata) - return CompiledKernel(so_path, metadata_group) + return CompiledKernel(src, metadata_group) # initialize metadata metadata = { "target": target, @@ -200,9 +230,16 @@ def compile(src, target=None, options=None): metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) - so_path = backend.make_launcher_stub(src, metadata) # return handle to compiled kernel - return CompiledKernel(so_path, metadata_group) + return CompiledKernel(src, metadata_group) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target[0]}) ({actives}). There should only be one.") + return actives[0](target) class CompiledKernel: @@ -212,15 +249,8 @@ class CompiledKernel: launch_enter_hook = None launch_exit_hook = None - def __init__(self, so_path, metadata_group): + def __init__(self, src, metadata_group): metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) - # initialize launcher - import importlib.util - spec = importlib.util.spec_from_file_location("__triton_launcher", so_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.run = getattr(mod, "launch") - # initialize metadata self.metadata = json.loads(metadata_path.read_text()) self.metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in self.metadata['tensormaps_info'] ] if 'tensormaps_info' in self.metadata else [] @@ -229,6 +259,8 @@ def __init__(self, so_path, metadata_group): self.name = self.metadata["name"] for key, val in self.metadata.items(): setattr(self, key, val) + # create launcher + self.run = driver.launcher_cls(src, self.metadata) # stores the text of each level of IR that was generated during compilation asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] self.asm = { diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 14ccdc1a0a..e69de29bb2 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -1,551 +0,0 @@ -import hashlib -import os -import tempfile - -from ..common import _build -from ..common.backend import get_cuda_version_key -from ..common.build import is_hip -from ..runtime.cache import get_cache_manager -from .utils import generate_cu_signature - - -def is_spirv(): - return os.environ.get("TRITON_TARGET_NVVM", "0") != "1" - - -# ----- stub -------- - - -def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): - # Get unique key for the compiled code - signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} - key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" - for kw in kwargs: - key = f"{key}-{kwargs.get(kw)}" - key = hashlib.md5(key.encode("utf-8")).hexdigest() - return key - - -def make_stub(name, signature, constants, ids, **kwargs): - # name of files that are cached - so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs) - so_cache_manager = get_cache_manager(so_cache_key) - so_name = f"{name}.so" - # retrieve stub from cache if it exists - cache_path = so_cache_manager.get_file(so_name) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src = generate_launcher(constants, signature, ids) - src_path = os.path.join(tmpdir, "main.cpp" if is_spirv() else "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir) - with open(so, "rb") as f: - return so_cache_manager.put(f.read(), so_name, binary=True) - else: - return cache_path - - -# ----- source code generation -------- - - -def ty_to_cpp(ty): - if ty[0] == '*': - return "void*" if is_spirv() else "hipDeviceptr_t" if is_hip() else "CUdeviceptr" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - -def generate_launcher(constants, signature, ids): - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - signature, desc_start_idx = generate_cu_signature(constants, signature, ids) - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - - def _extracted_type(ty): - if ty[0] == '*': - return "void*" if is_spirv() else "PyObject*" - return { - 'i1': 'int32_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] - - def format_of(ty): - return { - "PyObject*": "O", - "void*": "K", - "float": "f", - "double": "d", - "long": "l", - "uint32_t": "I", - "int32_t": "i", - "uint64_t": "K", - "int64_t": "L", - }[ty] - - format = "iiiiiiiiiKKOOOK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - if is_spirv(): - reg_launch_format = "iiiiiiiiiiKKKKKOOOK" + ''.join( - [format_of(_extracted_type(ty)) for ty in signature.values()]) - - # generate glue code - if is_spirv(): - src = f""" - #include - #include - #include - #include - #include - - #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - #include - #include - #include - - static inline void gpuAssert(ze_result_t code, const char *file, int line) - {{ - if (code != ZE_RESULT_SUCCESS) - {{ - const char* prefix = "Triton Error [ZE]: "; - std::string str = std::to_string(code); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str.c_str()); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - static void _regular_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int shared_memory, - ze_command_queue_handle_t queue, ze_device_handle_t _dev, ze_context_handle_t _ctxt, - ze_kernel_handle_t function, ze_event_pool_handle_t event_pool - {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - - if (gridX*gridY*gridZ > 0) {{ - {" ".join(f'zeKernelSetArgumentValue(function, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} - if (shared_memory) {{ - uint32_t num_params = sizeof(params)/sizeof(params[0]); - zeKernelSetArgumentValue(function, num_params, shared_memory, NULL); - }} - zeKernelSetGroupSize(function, 32*num_warps, 1, 1); - - ze_group_count_t grpCount = {{gridX, gridY, gridZ}}; - - // Create command list - ze_command_list_handle_t CmdList; - ze_command_list_desc_t CommandListDesc_ = {{ - ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, - nullptr, - 0, - 0, - }}; - - ZE_CHECK(zeCommandListCreate(_ctxt, _dev, &CommandListDesc_, &CmdList)); - - ze_event_desc_t eventDesc = {{ - ZE_STRUCTURE_TYPE_EVENT_DESC, - nullptr, - 0, - 0, - ZE_EVENT_SCOPE_FLAG_HOST - }}; - ze_event_handle_t hEvent; - ZE_CHECK(zeEventCreate(event_pool, &eventDesc, &hEvent)); - - // Append a signal of an event into the command list after the kernel executes - ZE_CHECK(zeCommandListAppendLaunchKernel(CmdList, function, &grpCount, hEvent, 0, nullptr)); - - // close command list - ZE_CHECK(zeCommandListClose(CmdList)); - - // FIXME: The following statement currently doesn't synchronize all IPEX SYCL queues. - // Needs to find all IPEX SYCL queues - // Synchronize the command queue to ensure previous IPEX SYCL commands complete before Triton kernel starts - // ZE_CHECK(zeCommandQueueSynchronize(queue, std::numeric_limits::max())); - - // execute command list - ZE_CHECK(zeCommandQueueExecuteCommandLists(queue, 1, &CmdList, nullptr)); - - // Wait on event to complete - ZE_CHECK(zeEventHostSynchronize(hEvent, std::numeric_limits::max())); - }} - }} - - static void _launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int shared_memory, - ze_command_list_handle_t queue, ze_kernel_handle_t function, ze_event_pool_handle_t event_pool - {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - - if (gridX*gridY*gridZ > 0) {{ - {" ".join(f'zeKernelSetArgumentValue(function, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} - if (shared_memory) {{ - uint32_t num_params = sizeof(params)/sizeof(params[0]); - zeKernelSetArgumentValue(function, num_params, shared_memory, NULL); - }} - zeKernelSetGroupSize(function, 32*num_warps, 1, 1); - ze_group_count_t grpCount = {{gridX, gridY, gridZ}}; - - ze_event_desc_t eventDesc = {{ - ZE_STRUCTURE_TYPE_EVENT_DESC, - nullptr, - 0, - 0, - ZE_EVENT_SCOPE_FLAG_HOST - }}; - ze_event_handle_t hEvent; - ZE_CHECK(zeEventCreate(event_pool, &eventDesc, &hEvent)); - - // FIXME: The following statement currently doesn't synchronize all IPEX SYCL queues. - // Needs to find all IPEX SYCL queues - // Synchronize to ensure previous IPEX SYCL commands complete before Triton kernel starts - ZE_CHECK(zeCommandListHostSynchronize(queue, std::numeric_limits::max())); - - // Append a signal of an event into the command list after the kernel executes - ZE_CHECK(zeCommandListAppendLaunchKernel(queue, function, &grpCount, hEvent, 0, nullptr)); - // Wait on event to complete - ZE_CHECK(zeEventHostSynchronize(hEvent, std::numeric_limits::max())); - }} - }} - - typedef struct _DevicePtrInfo {{ - void* dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - PyTypeObject* obj_type = Py_TYPE(obj); - - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (void*) PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = (void*) PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) {{ - return ptr_info; - }} - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - return ptr_info; - }} - - static PyObject* launch(PyObject* self, PyObject* args) {{ - - int gridX, gridY, gridZ; - uint64_t _queue; - uint64_t _stream; - uint64_t _function; - uint64_t _event_pool; - uint64_t _dev; - uint64_t _ctxt; - int num_warps; - int num_ctas; - int clusterDimX; - int clusterDimY; - int clusterDimZ; - int _is_icl; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if (!PyArg_ParseTuple(args, \"{reg_launch_format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, - &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_is_icl, &_stream, - &_queue, &_dev, &_ctxt, &_function, &launch_enter_hook, &launch_exit_hook, - &compiled_kernel, &_event_pool - {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ - return NULL; - }} - - if (launch_enter_hook != Py_None) {{ - PyObject_CallObject(launch_enter_hook, args); - }} - - // raise exception asap - // {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - if (_is_icl == 0) {{ - _regular_launch(gridX, gridY, gridZ, num_warps, shared_memory, (ze_command_queue_handle_t)_queue, - (ze_device_handle_t)_dev, (ze_context_handle_t)_ctxt, (ze_kernel_handle_t)_function, - (ze_event_pool_handle_t)_event_pool - {', ' + ', '.join(f"(void *) _arg{i}" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); - }} else {{ - _launch(gridX, gridY, gridZ, num_warps, shared_memory, (ze_command_list_handle_t)_stream, - (ze_kernel_handle_t)_function, (ze_event_pool_handle_t)_event_pool - {', ' + ', '.join(f"(void *) _arg{i}" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); - }} - - if (launch_exit_hook != Py_None) {{ - PyObject_CallObject(launch_exit_hook, args); - }} - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - else: - folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] - params = [ - i for i in signature.keys() - if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs) - ] - src = f""" -#include \"cuda.h\" -#include -#include -#include - -static inline void gpuAssert(CUresult code, const char *file, int line) -{{ - if (code != CUDA_SUCCESS) - {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - }} -}} - -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - -typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); - -static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ - // Open the shared library - void* handle = dlopen("libcuda.so", RTLD_LAZY); - if (!handle) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); - return NULL; - }} - // Clear any existing error - dlerror(); - cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so"); - return NULL; - }} - return cuLaunchKernelExHandle; -}} - -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; - if (gridX*gridY*gridZ > 0) {{ - if (num_ctas == 1) {{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} else {{ - CUlaunchAttribute launchAttr[2]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - CUlaunchConfig config; - config.gridDimX = gridX * clusterDimX; - config.gridDimY = gridY * clusterDimY; - config.gridDimZ = gridZ * clusterDimZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = 2; - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} - ptr_info.dev_ptr = dev_ptr; - Py_DECREF(ret); // Thanks ChatGPT! - return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int num_warps; - int num_ctas; - int clusterDimX; - int clusterDimY; - int clusterDimZ; - int shared_memory; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *compiled_kernel = NULL; - {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, - &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, - &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel - {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ - return NULL; - }} - - if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ - return NULL; - }} - - - // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); - Py_END_ALLOW_THREADS; - if (PyErr_Occurred()) {{ - return NULL; - }} - - if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src diff --git a/python/triton/language/math.py b/python/triton/language/math.py index 4277c78a95..7b3d3366e5 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -1,12 +1,13 @@ -import functools -import os from enum import IntEnum -from ..common.build import is_spirv -from ..common.build import is_hip from . import core +def is_spirv(): + import torch + return torch.xpu.is_available() + + class PropagateNan(IntEnum): """ PropagateNan is an enum class that specifies how NaNs are handled in min/max operations. @@ -17,30 +18,17 @@ class PropagateNan(IntEnum): NONE = 0x00000000 -@functools.lru_cache() -def libdevice_path(): - third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party") - if is_spirv(): - default = os.path.join(third_party_dir, "sycl", "lib", "libsycl-spir64-unknown-unknown.bc") - elif is_hip(): - default = os.path.join(third_party_dir, "hip", "lib", "bitcode", "cuda2gcn.bc") - else: - default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc") - - return os.getenv("TRITON_LIBDEVICE_PATH", default) - - @core.extern def clz(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__imf_clz", core.dtype("int32")), (core.dtype("int64"), ): ("__imf_clzll", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -50,13 +38,13 @@ def clz(arg0, _builder=None): def popc(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__imf_popc", core.dtype("int32")), (core.dtype("int64"), ): ("__imf_popcll", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -65,11 +53,11 @@ def popc(arg0, _builder=None): @core.extern def byte_perm(arg0, arg1, arg2, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2], { + return core.extern_elementwise("", "", [arg0, arg1, arg2], { (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__imf_byte_perm", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2], { + return core.extern_elementwise("", "", [arg0, arg1, arg2], { (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -124,7 +112,7 @@ def max(arg0, arg1, propagate_nan: core.constexpr = PropagateNan.NONE, _builder= def mulhi(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__imf_mulhi", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__imf_umulhi", core.dtype("uint32")), (core.dtype("int64"), core.dtype("int64")): ("__imf_mul64hi", core.dtype("int64")), @@ -132,7 +120,7 @@ def mulhi(arg0, arg1, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), @@ -144,13 +132,13 @@ def mulhi(arg0, arg1, _builder=None): def mul24(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__imf_mul24", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__imf_umul24", core.dtype("uint32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), }, is_pure=True, _builder=_builder) @@ -160,14 +148,14 @@ def mul24(arg0, arg1, _builder=None): def brev(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__imf_brev", core.dtype("int32")), (core.dtype("int64"), ): ("__imf_brevll", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -177,14 +165,14 @@ def brev(arg0, _builder=None): def sad(arg0, arg1, arg2, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__imf_sad", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__imf_usad", core.dtype("uint32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), }, is_pure=True, _builder=_builder) @@ -194,7 +182,7 @@ def sad(arg0, arg1, arg2, _builder=None): def abs(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__imf_abs", core.dtype("int32")), (core.dtype("int64"), ): ("__imf_llabs", core.dtype("int64")), (core.dtype("fp32"), ): ("__imf_fabsf", core.dtype("fp32")), @@ -202,7 +190,7 @@ def abs(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), @@ -214,13 +202,13 @@ def abs(arg0, _builder=None): def floor(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_floorf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_floor", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -228,7 +216,7 @@ def floor(arg0, _builder=None): @core.extern def rcp64h(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -237,13 +225,13 @@ def rcp64h(arg0, _builder=None): def rsqrt(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_rsqrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_rsqrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -253,14 +241,14 @@ def rsqrt(arg0, _builder=None): def ceil(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp64"), ): ("__imf_ceil", core.dtype("fp64")), (core.dtype("fp32"), ): ("__imf_ceilf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -270,13 +258,13 @@ def ceil(arg0, _builder=None): def trunc(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp64"), ): ("__imf_trunc", core.dtype("fp64")), (core.dtype("fp32"), ): ("__imf_truncf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -286,13 +274,13 @@ def trunc(arg0, _builder=None): def exp2(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_exp2f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_exp2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -301,11 +289,11 @@ def exp2(arg0, _builder=None): @core.extern def saturatef(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_saturatef", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -313,7 +301,7 @@ def saturatef(arg0, _builder=None): @core.extern def fma_rn(arg0, arg1, arg2, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -322,7 +310,7 @@ def fma_rn(arg0, arg1, arg2, _builder=None): @core.extern def fma_rz(arg0, arg1, arg2, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -331,7 +319,7 @@ def fma_rz(arg0, arg1, arg2, _builder=None): @core.extern def fma_rd(arg0, arg1, arg2, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -340,7 +328,7 @@ def fma_rd(arg0, arg1, arg2, _builder=None): @core.extern def fma_ru(arg0, arg1, arg2, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -348,7 +336,7 @@ def fma_ru(arg0, arg1, arg2, _builder=None): @core.extern def fast_dividef(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + return core.extern_elementwise("", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -356,7 +344,7 @@ def fast_dividef(arg0, arg1, _builder=None): @core.extern def div_rn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -365,7 +353,7 @@ def div_rn(arg0, arg1, _builder=None): @core.extern def div_rz(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -374,7 +362,7 @@ def div_rz(arg0, arg1, _builder=None): @core.extern def div_rd(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -383,7 +371,7 @@ def div_rd(arg0, arg1, _builder=None): @core.extern def div_ru(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -392,7 +380,7 @@ def div_ru(arg0, arg1, _builder=None): @core.extern def rcp_rn(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -401,7 +389,7 @@ def rcp_rn(arg0, _builder=None): @core.extern def rcp_rz(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -410,7 +398,7 @@ def rcp_rz(arg0, _builder=None): @core.extern def rcp_rd(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -419,7 +407,7 @@ def rcp_rd(arg0, _builder=None): @core.extern def rcp_ru(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -428,7 +416,7 @@ def rcp_ru(arg0, _builder=None): @core.extern def sqrt_rn(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -437,7 +425,7 @@ def sqrt_rn(arg0, _builder=None): @core.extern def sqrt_rz(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -446,7 +434,7 @@ def sqrt_rz(arg0, _builder=None): @core.extern def sqrt_rd(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -455,7 +443,7 @@ def sqrt_rd(arg0, _builder=None): @core.extern def sqrt_ru(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -465,13 +453,13 @@ def sqrt_ru(arg0, _builder=None): def sqrt(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_sqrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_sqrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -480,7 +468,7 @@ def sqrt(arg0, _builder=None): @core.extern def add_rn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -489,7 +477,7 @@ def add_rn(arg0, arg1, _builder=None): @core.extern def add_rz(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -498,7 +486,7 @@ def add_rz(arg0, arg1, _builder=None): @core.extern def add_rd(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -507,7 +495,7 @@ def add_rd(arg0, arg1, _builder=None): @core.extern def add_ru(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -516,7 +504,7 @@ def add_ru(arg0, arg1, _builder=None): @core.extern def mul_rn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -525,7 +513,7 @@ def mul_rn(arg0, arg1, _builder=None): @core.extern def mul_rz(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -534,7 +522,7 @@ def mul_rz(arg0, arg1, _builder=None): @core.extern def mul_rd(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -543,7 +531,7 @@ def mul_rd(arg0, arg1, _builder=None): @core.extern def mul_ru(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, arg1, ], { @@ -561,11 +549,11 @@ def mul_ru(arg0, arg1, _builder=None): @core.extern def double2float_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -573,11 +561,11 @@ def double2float_rn(arg0, _builder=None): @core.extern def double2float_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -585,11 +573,11 @@ def double2float_rz(arg0, _builder=None): @core.extern def double2float_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -597,11 +585,11 @@ def double2float_rd(arg0, _builder=None): @core.extern def double2float_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -609,11 +597,11 @@ def double2float_ru(arg0, _builder=None): @core.extern def double2int_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2int_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -621,11 +609,11 @@ def double2int_rn(arg0, _builder=None): @core.extern def double2int_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2int_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -633,11 +621,11 @@ def double2int_rz(arg0, _builder=None): @core.extern def double2int_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2int_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -645,11 +633,11 @@ def double2int_rd(arg0, _builder=None): @core.extern def double2int_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2int_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -657,11 +645,11 @@ def double2int_ru(arg0, _builder=None): @core.extern def double2uint_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2uint_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -669,11 +657,11 @@ def double2uint_rn(arg0, _builder=None): @core.extern def double2uint_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2uint_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -681,11 +669,11 @@ def double2uint_rz(arg0, _builder=None): @core.extern def double2uint_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2uint_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -693,11 +681,11 @@ def double2uint_rd(arg0, _builder=None): @core.extern def double2uint_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2uint_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -705,11 +693,11 @@ def double2uint_ru(arg0, _builder=None): @core.extern def int2double_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -717,11 +705,11 @@ def int2double_rn(arg0, _builder=None): @core.extern def uint2double_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -729,11 +717,11 @@ def uint2double_rn(arg0, _builder=None): @core.extern def float2int_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2int_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -741,11 +729,11 @@ def float2int_rn(arg0, _builder=None): @core.extern def float2int_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2int_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -753,11 +741,11 @@ def float2int_rz(arg0, _builder=None): @core.extern def float2int_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2int_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -765,11 +753,11 @@ def float2int_rd(arg0, _builder=None): @core.extern def float2int_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2int_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -777,11 +765,11 @@ def float2int_ru(arg0, _builder=None): @core.extern def float2uint_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2uint_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -789,11 +777,11 @@ def float2uint_rn(arg0, _builder=None): @core.extern def float2uint_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2uint_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -801,11 +789,11 @@ def float2uint_rz(arg0, _builder=None): @core.extern def float2uint_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2uint_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -813,11 +801,11 @@ def float2uint_rd(arg0, _builder=None): @core.extern def float2uint_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2uint_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -825,11 +813,11 @@ def float2uint_ru(arg0, _builder=None): @core.extern def int2float_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -837,11 +825,11 @@ def int2float_rn(arg0, _builder=None): @core.extern def int2float_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -849,11 +837,11 @@ def int2float_rz(arg0, _builder=None): @core.extern def int2float_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -861,11 +849,11 @@ def int2float_rd(arg0, _builder=None): @core.extern def int2float_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -873,11 +861,11 @@ def int2float_ru(arg0, _builder=None): @core.extern def uint2float_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -885,11 +873,11 @@ def uint2float_rn(arg0, _builder=None): @core.extern def uint2float_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -897,11 +885,11 @@ def uint2float_rz(arg0, _builder=None): @core.extern def uint2float_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -909,11 +897,11 @@ def uint2float_rd(arg0, _builder=None): @core.extern def uint2float_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -921,11 +909,11 @@ def uint2float_ru(arg0, _builder=None): @core.extern def hiloint2double(arg0, arg1, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + return core.extern_elementwise("", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__imf_hiloint2double", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + return core.extern_elementwise("", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -933,11 +921,11 @@ def hiloint2double(arg0, arg1, _builder=None): @core.extern def double2loint(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2loint", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -945,11 +933,11 @@ def double2loint(arg0, _builder=None): @core.extern def double2hiint(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2hiint", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -957,11 +945,11 @@ def double2hiint(arg0, _builder=None): @core.extern def float2ll_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ll_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -969,11 +957,11 @@ def float2ll_rn(arg0, _builder=None): @core.extern def float2ll_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ll_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -981,11 +969,11 @@ def float2ll_rz(arg0, _builder=None): @core.extern def float2ll_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ll_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -993,11 +981,11 @@ def float2ll_rd(arg0, _builder=None): @core.extern def float2ll_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ll_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1005,11 +993,11 @@ def float2ll_ru(arg0, _builder=None): @core.extern def float2ull_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ull_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1017,11 +1005,11 @@ def float2ull_rn(arg0, _builder=None): @core.extern def float2ull_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ull_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1029,11 +1017,11 @@ def float2ull_rz(arg0, _builder=None): @core.extern def float2ull_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ull_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1041,11 +1029,11 @@ def float2ull_rd(arg0, _builder=None): @core.extern def float2ull_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float2ull_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1053,11 +1041,11 @@ def float2ull_ru(arg0, _builder=None): @core.extern def double2ll_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ll_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1065,11 +1053,11 @@ def double2ll_rn(arg0, _builder=None): @core.extern def double2ll_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ll_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1077,11 +1065,11 @@ def double2ll_rz(arg0, _builder=None): @core.extern def double2ll_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ll_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1089,11 +1077,11 @@ def double2ll_rd(arg0, _builder=None): @core.extern def double2ll_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ll_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1101,11 +1089,11 @@ def double2ll_ru(arg0, _builder=None): @core.extern def double2ull_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ull_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1113,11 +1101,11 @@ def double2ull_rn(arg0, _builder=None): @core.extern def double2ull_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ull_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1125,11 +1113,11 @@ def double2ull_rz(arg0, _builder=None): @core.extern def double2ull_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ull_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1137,11 +1125,11 @@ def double2ull_rd(arg0, _builder=None): @core.extern def double2ull_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double2ull_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -1149,11 +1137,11 @@ def double2ull_ru(arg0, _builder=None): @core.extern def ll2float_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1161,11 +1149,11 @@ def ll2float_rn(arg0, _builder=None): @core.extern def ll2float_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1173,11 +1161,11 @@ def ll2float_rz(arg0, _builder=None): @core.extern def ll2float_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1185,11 +1173,11 @@ def ll2float_rd(arg0, _builder=None): @core.extern def ll2float_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1197,12 +1185,12 @@ def ll2float_ru(arg0, _builder=None): @core.extern def ull2float_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1210,11 +1198,11 @@ def ull2float_rn(arg0, _builder=None): @core.extern def ull2float_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1222,11 +1210,11 @@ def ull2float_rz(arg0, _builder=None): @core.extern def ull2float_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1234,11 +1222,11 @@ def ull2float_rd(arg0, _builder=None): @core.extern def ull2float_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1246,11 +1234,11 @@ def ull2float_ru(arg0, _builder=None): @core.extern def ll2double_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1258,11 +1246,11 @@ def ll2double_rn(arg0, _builder=None): @core.extern def ll2double_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2double_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1270,11 +1258,11 @@ def ll2double_rz(arg0, _builder=None): @core.extern def ll2double_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2double_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1282,11 +1270,11 @@ def ll2double_rd(arg0, _builder=None): @core.extern def ll2double_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_ll2double_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1294,11 +1282,11 @@ def ll2double_ru(arg0, _builder=None): @core.extern def ull2double_rn(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1306,11 +1294,11 @@ def ull2double_rn(arg0, _builder=None): @core.extern def ull2double_rz(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2double_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1318,11 +1306,11 @@ def ull2double_rz(arg0, _builder=None): @core.extern def ull2double_rd(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2double_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1330,11 +1318,11 @@ def ull2double_rd(arg0, _builder=None): @core.extern def ull2double_ru(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__imf_ull2double_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1342,11 +1330,11 @@ def ull2double_ru(arg0, _builder=None): @core.extern def int_as_float(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__imf_int_as_float", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1354,11 +1342,11 @@ def int_as_float(arg0, _builder=None): @core.extern def float_as_int(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float_as_int", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -1366,11 +1354,11 @@ def float_as_int(arg0, _builder=None): @core.extern def uint_as_float(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__imf_uint_as_float", core.dtype("fp32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1378,11 +1366,11 @@ def uint_as_float(arg0, _builder=None): @core.extern def float_as_uint(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_float_as_uint", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -1390,11 +1378,11 @@ def float_as_uint(arg0, _builder=None): @core.extern def longlong_as_double(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__imf_longlong_as_double", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1402,74 +1390,74 @@ def longlong_as_double(arg0, _builder=None): @core.extern def double_as_longlong(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__imf_double_as_longlong", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), }, is_pure=True, _builder=_builder) @core.extern def fast_sinf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_cosf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_log2f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_logf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_expf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_tanf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_exp10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_log10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @core.extern def fast_powf(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + return core.extern_elementwise("", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), }, is_pure=True, _builder=_builder) @@ -1477,7 +1465,7 @@ def fast_powf(arg0, arg1, _builder=None): @core.extern def hadd(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), }, is_pure=True, _builder=_builder) @@ -1487,13 +1475,13 @@ def hadd(arg0, arg1, _builder=None): def rhadd(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__imf_rhadd", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__imf_urhadd", core.dtype("uint32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), }, is_pure=True, _builder=_builder) @@ -1502,7 +1490,7 @@ def rhadd(arg0, arg1, _builder=None): @core.extern def sub_rn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1511,7 +1499,7 @@ def sub_rn(arg0, arg1, _builder=None): @core.extern def sub_rz(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1520,7 +1508,7 @@ def sub_rz(arg0, arg1, _builder=None): @core.extern def sub_rd(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1529,7 +1517,7 @@ def sub_rd(arg0, arg1, _builder=None): @core.extern def sub_ru(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1537,7 +1525,7 @@ def sub_ru(arg0, arg1, _builder=None): @core.extern def rsqrt_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [ + return core.extern_elementwise("", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), @@ -1548,7 +1536,7 @@ def rsqrt_rn(arg0, _builder=None): def ffs(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("int32"), ): ("__imf_ffs", core.dtype("int32")), @@ -1556,7 +1544,7 @@ def ffs(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), @@ -1568,7 +1556,7 @@ def ffs(arg0, _builder=None): def rint(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__imf_rintf", core.dtype("fp32")), @@ -1576,7 +1564,7 @@ def rint(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), @@ -1588,7 +1576,7 @@ def rint(arg0, _builder=None): def llrint(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__imf_llrintf", core.dtype("int64")), @@ -1596,7 +1584,7 @@ def llrint(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), @@ -1608,7 +1596,7 @@ def llrint(arg0, _builder=None): def nearbyint(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__imf_nearbyintf", core.dtype("fp32")), @@ -1616,7 +1604,7 @@ def nearbyint(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), @@ -1628,7 +1616,7 @@ def nearbyint(arg0, _builder=None): def isnan(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__imf_isnanf", core.dtype("int32")), @@ -1636,7 +1624,7 @@ def isnan(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), @@ -1648,7 +1636,7 @@ def isnan(arg0, _builder=None): def signbit(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__imf_signbitf", core.dtype("int32")), @@ -1656,7 +1644,7 @@ def signbit(arg0, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [ + "", "", [ arg0, ], { (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), @@ -1668,13 +1656,13 @@ def signbit(arg0, _builder=None): def copysign(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_copysignf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_copysign", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1683,11 +1671,11 @@ def copysign(arg0, arg1, _builder=None): @core.extern def finitef(arg0, _builder=None): if is_spirv(): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__imf_isfinitef", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -1696,13 +1684,13 @@ def finitef(arg0, _builder=None): def isinf(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_isinff", core.dtype("int32")), (core.dtype("fp64"), ): ("__imf_isinf", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -1712,13 +1700,13 @@ def isinf(arg0, _builder=None): def nextafter(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_nextafterf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_nextafter", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1728,13 +1716,13 @@ def nextafter(arg0, arg1, _builder=None): def sin(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_sinf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_sin", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1744,13 +1732,13 @@ def sin(arg0, _builder=None): def cos(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_cosf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_cos", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1760,13 +1748,13 @@ def cos(arg0, _builder=None): def sinpi(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_sinpif", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_sinpi", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1776,13 +1764,13 @@ def sinpi(arg0, _builder=None): def cospi(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_cospif", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_cospi", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1792,13 +1780,13 @@ def cospi(arg0, _builder=None): def tan(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_tanf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_tan", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1808,13 +1796,13 @@ def tan(arg0, _builder=None): def log2(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_log2f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_log2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1824,13 +1812,13 @@ def log2(arg0, _builder=None): def exp(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_expf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_exp", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1840,13 +1828,13 @@ def exp(arg0, _builder=None): def exp10(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_exp10f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_exp10", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1856,13 +1844,13 @@ def exp10(arg0, _builder=None): def cosh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_coshf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_cosh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1872,13 +1860,13 @@ def cosh(arg0, _builder=None): def sinh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_sinhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_sinh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1888,13 +1876,13 @@ def sinh(arg0, _builder=None): def tanh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_tanhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_tanh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1904,13 +1892,13 @@ def tanh(arg0, _builder=None): def atan2(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_atan2f", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_atan2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1920,13 +1908,13 @@ def atan2(arg0, arg1, _builder=None): def atan(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_atanf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_atan", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1936,13 +1924,13 @@ def atan(arg0, _builder=None): def asin(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_asinf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_asin", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1952,13 +1940,13 @@ def asin(arg0, _builder=None): def acos(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_acosf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_acos", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1968,13 +1956,13 @@ def acos(arg0, _builder=None): def log(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_logf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_log", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -1984,13 +1972,13 @@ def log(arg0, _builder=None): def log10(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_log10f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_log10", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2000,13 +1988,13 @@ def log10(arg0, _builder=None): def log1p(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_log1pf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_log1p", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2016,13 +2004,13 @@ def log1p(arg0, _builder=None): def acosh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_acoshf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_acosh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2032,13 +2020,13 @@ def acosh(arg0, _builder=None): def asinh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_asinhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_asinh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2048,13 +2036,13 @@ def asinh(arg0, _builder=None): def atanh(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_atanhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_atanh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2064,13 +2052,13 @@ def atanh(arg0, _builder=None): def expm1(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_expm1f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_expm1", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2080,13 +2068,13 @@ def expm1(arg0, _builder=None): def hypot(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_hypotf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_hypot", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2096,13 +2084,13 @@ def hypot(arg0, arg1, _builder=None): def rhypot(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_rhypotf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_rhypot", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2112,13 +2100,13 @@ def rhypot(arg0, arg1, _builder=None): def norm3d(arg0, arg1, arg2, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__imf_norm3df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__imf_norm3d", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2128,13 +2116,13 @@ def norm3d(arg0, arg1, arg2, _builder=None): def rnorm3d(arg0, arg1, arg2, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__imf_rnorm3df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__imf_rnorm3d", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2144,7 +2132,7 @@ def rnorm3d(arg0, arg1, arg2, _builder=None): def norm4d(arg0, arg1, arg2, arg3, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + "", "", [arg0, arg1, arg2, arg3], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__imf_norm4df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): @@ -2152,7 +2140,7 @@ def norm4d(arg0, arg1, arg2, arg3, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + "", "", [arg0, arg1, arg2, arg3], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm4df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): @@ -2164,7 +2152,7 @@ def norm4d(arg0, arg1, arg2, arg3, _builder=None): def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + "", "", [arg0, arg1, arg2, arg3], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__imf_rnorm4df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): @@ -2172,7 +2160,7 @@ def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + "", "", [arg0, arg1, arg2, arg3], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm4df", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): @@ -2184,13 +2172,13 @@ def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): def cbrt(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_cbrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_cbrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2200,13 +2188,13 @@ def cbrt(arg0, _builder=None): def rcbrt(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_rcbrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_rcbrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2214,7 +2202,7 @@ def rcbrt(arg0, _builder=None): @core.extern def j0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2222,7 +2210,7 @@ def j0(arg0, _builder=None): @core.extern def j1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2230,7 +2218,7 @@ def j1(arg0, _builder=None): @core.extern def y0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2238,7 +2226,7 @@ def y0(arg0, _builder=None): @core.extern def y1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2247,7 +2235,7 @@ def y1(arg0, _builder=None): @core.extern def yn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2256,7 +2244,7 @@ def yn(arg0, arg1, _builder=None): @core.extern def jn(arg0, arg1, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2265,7 +2253,7 @@ def jn(arg0, arg1, _builder=None): @core.extern def cyl_bessel_i0(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2274,7 +2262,7 @@ def cyl_bessel_i0(arg0, _builder=None): @core.extern def cyl_bessel_i1(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2284,13 +2272,13 @@ def cyl_bessel_i1(arg0, _builder=None): def erf(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_erff", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_erf", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2300,13 +2288,13 @@ def erf(arg0, _builder=None): def erfinv(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_erfcinvf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_erfcinv", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2316,13 +2304,13 @@ def erfinv(arg0, _builder=None): def erfc(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_erfcf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_erfc", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2332,13 +2320,13 @@ def erfc(arg0, _builder=None): def erfcx(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_erfcxf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_erfcx", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2347,7 +2335,7 @@ def erfcx(arg0, _builder=None): @core.extern def erfcinv(arg0, _builder=None): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2357,13 +2345,13 @@ def erfcinv(arg0, _builder=None): def normcdfinv(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cdnorminvf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cdnorminv", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2373,13 +2361,13 @@ def normcdfinv(arg0, _builder=None): def normcdf(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_cdnormf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_cdnorm", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2389,13 +2377,13 @@ def normcdf(arg0, _builder=None): def lgamma(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_lgammaf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_lgamma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2405,13 +2393,13 @@ def lgamma(arg0, _builder=None): def ldexp(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__imf_ldexpf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__imf_ldexp", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2421,13 +2409,13 @@ def ldexp(arg0, arg1, _builder=None): def scalbn(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__imf_scalbnf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__imf_scalbn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2437,13 +2425,13 @@ def scalbn(arg0, arg1, _builder=None): def fmod(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_fmodf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_fmod", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2453,13 +2441,13 @@ def fmod(arg0, arg1, _builder=None): def remainder(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_remainderf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_remainder", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2469,13 +2457,13 @@ def remainder(arg0, arg1, _builder=None): def fma(arg0, arg1, arg2, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1, arg2], { + "", "", [arg0, arg1, arg2], { (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__imf_fmaf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__imf_fma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2485,7 +2473,7 @@ def fma(arg0, arg1, arg2, _builder=None): def pow(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__imf_powif", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__imf_powi", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__imf_powf", core.dtype("fp32")), @@ -2493,7 +2481,7 @@ def pow(arg0, arg1, _builder=None): }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), @@ -2505,13 +2493,13 @@ def pow(arg0, arg1, _builder=None): def tgamma(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_tgammaf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_tgamma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2521,13 +2509,13 @@ def tgamma(arg0, _builder=None): def round(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_roundf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_round", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2537,13 +2525,13 @@ def round(arg0, _builder=None): def llround(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_llroundf", core.dtype("int64")), (core.dtype("fp64"), ): ("__imf_llround", core.dtype("int64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), }, is_pure=True, _builder=_builder) @@ -2553,13 +2541,13 @@ def llround(arg0, _builder=None): def fdim(arg0, arg1, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__imf_fdimf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__imf_fdim", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0, arg1], { + "", "", [arg0, arg1], { (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2569,13 +2557,13 @@ def fdim(arg0, arg1, _builder=None): def ilogb(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_ilogbf", core.dtype("int32")), (core.dtype("fp64"), ): ("__imf_ilogb", core.dtype("int32")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), }, is_pure=True, _builder=_builder) @@ -2585,13 +2573,13 @@ def ilogb(arg0, _builder=None): def logb(arg0, _builder=None): if is_spirv(): return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__imf_logf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__imf_log", core.dtype("fp64")), }, is_pure=True, _builder=_builder) else: return core.extern_elementwise( - "libdevice", libdevice_path(), [arg0], { + "", "", [arg0], { (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), }, is_pure=True, _builder=_builder) @@ -2599,6 +2587,6 @@ def logb(arg0, _builder=None): @core.extern def isfinited(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + return core.extern_elementwise("", "", [arg0], { (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), }, is_pure=True, _builder=_builder) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 06f35aa26f..98d368ac04 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -4,7 +4,6 @@ from typing import List, Optional, Sequence, Tuple, TypeVar from .._C.libtriton import ir -from ..common.build import is_hip from . import core as tl T = TypeVar('T') @@ -1201,19 +1200,6 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope # ===----------------------------------------------------------------------===// -def gpu_has_mfma() -> bool: - if not is_hip(): - return False - return True # mfma supported in ['gfx908', 'gfx90a'] - - -def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: - if not gpu_has_mfma(): - return False - # TODO: Add check for configurations and types. - return True - - def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: @@ -1276,28 +1262,29 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): N = rhs.type.shape[1] # Cast operands of types f16 and i8 for configurations where FMA only supported. - if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): - ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty - lhs = cast(lhs, ret_cast_scalar_ty, builder) - rhs = cast(rhs, ret_cast_scalar_ty, builder) - if ret_cast_scalar_ty == tl.float16: - _0 = builder.create_splat(builder.get_fp16(0), [M, N]) - else: - _0 = builder.create_splat(builder.get_fp32(0), [M, N]) - ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) - return cast(ret, ret_scalar_ty, builder) - if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, - ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: - if lhs.type.scalar.is_int(): - ret_dot_scalar_ty = tl.int32 - _0 = builder.create_splat(builder.get_int32(0), [M, N]) - else: - ret_dot_scalar_ty = tl.float32 - _0 = builder.create_splat(builder.get_fp32(0), [M, N]) - ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) - return cast(ret, ret_scalar_ty, builder) + # TODO: builder should contain target information + # if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): + # ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty + # lhs = cast(lhs, ret_cast_scalar_ty, builder) + # rhs = cast(rhs, ret_cast_scalar_ty, builder) + # if ret_cast_scalar_ty == tl.float16: + # _0 = builder.create_splat(builder.get_fp16(0), [M, N]) + # else: + # _0 = builder.create_splat(builder.get_fp32(0), [M, N]) + # ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) + # ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) + # return cast(ret, ret_scalar_ty, builder) + # if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, + # ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: + # if lhs.type.scalar.is_int(): + # ret_dot_scalar_ty = tl.int32 + # _0 = builder.create_splat(builder.get_int32(0), [M, N]) + # else: + # ret_dot_scalar_ty = tl.float32 + # _0 = builder.create_splat(builder.get_fp32(0), [M, N]) + # ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) + # ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) + # return cast(ret, ret_scalar_ty, builder) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) if acc is None: acc_handle = builder.create_splat(_0, [M, N]) diff --git a/python/triton/runtime/backends/hip.c b/python/triton/runtime/backends/hip.c deleted file mode 100644 index 64679f8d6c..0000000000 --- a/python/triton/runtime/backends/hip.c +++ /dev/null @@ -1,108 +0,0 @@ -#define __HIP_PLATFORM_AMD__ -#include -#define PY_SSIZE_T_CLEAN -#include -#include -#include - -static inline void gpuAssert(hipError_t code, const char *file, int line) { - { - if (code != HIP_SUCCESS) { - { - const char *prefix = "Triton Error [HIP]: "; - const char *str = hipGetErrorString(code); - char err[1024] = {0}; - snprintf(err, 1024, "%s Code: %d, Message: %s", prefix, code, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - } - } - } -} - -#define HIP_CHECK(ans) \ - { \ - gpuAssert((ans), __FILE__, __LINE__); \ - if (PyErr_Occurred()) \ - return NULL; \ - } - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - hipDeviceProp_t props; - HIP_CHECK(hipGetDeviceProperties(&props, device_id)); - - // create a struct to hold device properties - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:s}", "max_shared_mem", - props.sharedMemPerBlock, "multiprocessor_count", - props.multiProcessorCount, "sm_clock_rate", - props.clockRate, "mem_clock_rate", props.memoryClockRate, - "mem_bus_width", props.memoryBusWidth, "arch", - props.gcnArchName); -} - -static PyObject *loadBinary(PyObject *self, PyObject *args) { - const char *name; - const char *data; - Py_ssize_t data_size; - int shared; - int device; - if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, - &device)) { - return NULL; - } - - // set HIP options - hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, - hipJitOptionErrorLogBuffer, - hipJitOptionInfoLogBufferSizeBytes, - hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose}; - const unsigned int errbufsize = 8192; - const unsigned int logbufsize = 8192; - char _err[errbufsize]; - char _log[logbufsize]; - void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err, - (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1}; - - // launch HIP Binary - hipModule_t mod; - hipFunction_t fun; - HIP_CHECK(hipModuleLoadDataEx(&mod, data, 5, opt, optval)) - HIP_CHECK(hipModuleGetFunction(&fun, mod, name)); - - // get allocated registers and spilled registers from the function - int n_regs = 0; - int n_spills = 0; - if (PyErr_Occurred()) { - return NULL; - } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, - n_spills); -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBinary, METH_VARARGS, - "Load provided hsaco into HIP driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_hip_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; -} diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py new file mode 100644 index 0000000000..bc5e3b1ce0 --- /dev/null +++ b/python/triton/runtime/build.py @@ -0,0 +1,104 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +def is_spirv(): + import torch + return torch.xpu.is_available() + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + + if is_spirv(): + icpx = None + cxx = os.environ.get("CXX") + if cxx is None: + clangpp = shutil.which("clang++") + icpx = shutil.which("icpx") + cxx = icpx if icpx is not None else clangpp + import numpy as np + numpy_include_dir = np.get_include() + include_dirs = include_dirs + [numpy_include_dir] + cxx_cmd = [cxx, src] + if icpx is not None: + cxx_cmd += ["-fsycl"] + cxx_cmd += ["-std=c++17", "-g", "-shared", "-fPIC", "-o", so] + cxx_cmd += [f'-l{lib}' for lib in libraries] + cxx_cmd += [f"-L{dir}" for dir in library_dirs] + cxx_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cxx_cmd) + else: + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 03df4f6fec..f56e806064 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, Optional +import hashlib def default_cache_dir(): @@ -157,3 +158,13 @@ def get_override_manager(key) -> CacheManager: def get_dump_manager(key) -> CacheManager: return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 799d60017a..afddd62277 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -1,622 +1,11 @@ -import abc -import hashlib -import os -import tempfile -from pathlib import Path -import functools +from ..backends import backends -from ..common.build import _build -from .cache import get_cache_manager -from ..runtime import driver -import intel_extension_for_pytorch as ipex - - -class DriverBase(metaclass=abc.ABCMeta): - CUDA = 0 - HIP = 1 - SPIRV = 2 - - @staticmethod - def third_party_dir(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party") - - def __init__(self) -> None: - pass - - -# ----------------------------- -# Torch-GPU -# ----------------------------- - - -class FrameworkGPUDriver(DriverBase): - - def __init__(self): - # TODO: support other frameworks than torch - import torch - self.get_device_capability = torch.cuda.get_device_capability - try: - from torch._C import _cuda_getCurrentRawStream - self.get_current_stream = _cuda_getCurrentRawStream - except ImportError: - self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - - -# ----------------------------- -# CUDA -# ----------------------------- - - -class CudaUtils(object): - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(CudaUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text() - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - fname = "cuda_utils.so" - cache_path = cache.get_file(fname) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build("cuda_utils", src_path, tmpdir) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), fname, binary=True) - import importlib.util - - spec = importlib.util.spec_from_file_location("cuda_utils", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.CUtensorMapDataType = mod.CUtensorMapDataType - self.CUtensorMapInterleave = mod.CUtensorMapInterleave - self.CUtensorMapSwizzle = mod.CUtensorMapSwizzle - self.CUtensorMapL2promotion = mod.CUtensorMapL2promotion - self.CUtensorMapFloatOOBfill = mod.CUtensorMapFloatOOBfill - self.cuTensorMapEncodeTiled = mod.cuTensorMapEncodeTiled - self.cuMemAlloc = mod.cuMemAlloc - self.cuMemcpyHtoD = mod.cuMemcpyHtoD - self.cuMemFree = mod.cuMemFree - self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters - - -class TensorMapManager: - - def __init__(self): - self.tensormaps_device = {} - - def __getitem__(self, key: tuple): - if key in self.tensormaps_device: - return int(self.tensormaps_device[key]) - else: - (e, args) = key - t_tensormap = e.tensormap(args) - TENSORMAP_SIZE_IN_BYTES = 128 - t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) - driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) - self.tensormaps_device[key] = t_tensormap_device - return int(self.tensormaps_device[key]) - - def __del__(self): - for _, v in self.tensormaps_device.items(): - driver.utils.cuMemFree(v) - - -class CudaDriver(FrameworkGPUDriver): - tensormap_manager = TensorMapManager() - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(CudaDriver, cls).__new__(cls) - return cls.instance - - def __init__(self): - self.utils = CudaUtils() - self.backend = self.CUDA - self.binary_ext = "cubin" - super().__init__() - - @functools.lru_cache() - def get_current_target(self): - device = self.get_current_device() - capability = self.get_device_capability(device) - capability = capability[0] * 10 + capability[1] - return ("cuda", capability) - - def assemble_tensormap_to_arg(self, tensormaps_info, args): - args_with_tma = list(args) - if tensormaps_info is not None: - # tuple for hashable - args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) - for i, e in enumerate(tensormaps_info): - args_with_tma.append(CudaDriver.tensormap_manager[(e, args_ptr)]) - return args_with_tma - - -# ----------------------------- -# HIP -# ----------------------------- - - -class HIPUtils(object): - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(HIPUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - src = Path(os.path.join(dirname, "backends", "hip.c")).read_text() - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - fname = "hip_utils.so" - cache_path = cache.get_file(fname) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build("hip_utils", src_path, tmpdir) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), fname, binary=True) - import importlib.util - - spec = importlib.util.spec_from_file_location("hip_utils", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - - -class HIPDriver(FrameworkGPUDriver): - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(HIPDriver, cls).__new__(cls) - return cls.instance - - def __init__(self): - super().__init__() - self.utils = HIPUtils() - self.backend = self.HIP - self.binary_ext = "hsaco" - - def get_current_target(self): - device = self.get_current_device() - arch = self.utils.get_device_properties(device)['arch'] - return ("hip", arch.split(':')[0]) - - def assemble_tensormap_to_arg(self, tensormaps_info, args): - return args - - -# ----------------------------- -# SPIRV -# ----------------------------- - - -class SpirvUtils(object): - - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(SpirvUtils, cls).__new__(cls) - return cls.instance - - @staticmethod - def _generate_src(): - return """ - #include - #include - #include - #include - #include - #include - #include - #include - - #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - #include - #include - - typedef struct l0_resc_handles { - ze_context_handle_t context; - ze_device_handle_t device; - ze_command_queue_handle_t queue; - ze_command_list_handle_t cmd_list; - }l0_resc_handles; - - std::unordered_map sycl_queue_map; - static ze_context_handle_t context = {nullptr}; - static ze_driver_handle_t driverHandle = {nullptr}; - static ze_event_pool_handle_t eventPoolHandle = {nullptr}; - - static std::vector devices; - - static inline void gpuAssert(ze_result_t code, const char *file, int line) - { - if (code != ZE_RESULT_SUCCESS) - { - const char* prefix = "Triton Error [ZE]: "; - std::string str = std::to_string(code); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str.c_str()); - PyErr_SetString(PyExc_RuntimeError, err); - } - } - - #define ZE_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; } - - static PyObject* getDeviceProperties(PyObject* self, PyObject* args){ - int device_id; - if(!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - if (device_id > devices.size()) { - std::cout << "Device ID not found: " << device_id << std::endl; - return NULL; - } - - // Get device handle - ze_device_handle_t phDevice = devices[device_id]; - - // create a struct to hold device properties - ze_device_properties_t device_properties = {}; - device_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; - zeDeviceGetProperties(phDevice, &device_properties); - - int multiprocessor_count = device_properties.numSlices * device_properties.numSubslicesPerSlice; - int sm_clock_rate = device_properties.coreClockRate; - - ze_device_compute_properties_t compute_properties = {}; - compute_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES; - zeDeviceGetComputeProperties(phDevice, &compute_properties); - int max_shared_mem = compute_properties.maxSharedLocalMemory; - - uint32_t memoryCount = 0; - zeDeviceGetMemoryProperties(phDevice, &memoryCount, nullptr); - auto pMemoryProperties = new ze_device_memory_properties_t[memoryCount]; - for( uint32_t mem = 0; mem < memoryCount; ++mem ) - { - pMemoryProperties[mem].stype = ZE_STRUCTURE_TYPE_DEVICE_MEMORY_PROPERTIES; - pMemoryProperties[mem].pNext = nullptr; - } - zeDeviceGetMemoryProperties(phDevice, &memoryCount, pMemoryProperties); - // for( uint32_t mem = 0; mem < memoryCount; ++mem ) - // { - // std::cout << to_string( pMemoryProperties[ mem ] ) << std::endl; - // } - - int mem_clock_rate = pMemoryProperties[0].maxClockRate; - int mem_bus_width = pMemoryProperties[0].maxBusWidth; - - delete[] pMemoryProperties; - - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem, - "multiprocessor_count", multiprocessor_count, - "sm_clock_rate", sm_clock_rate, - "mem_clock_rate", mem_clock_rate, - "mem_bus_width", mem_bus_width); - } - - static PyObject* loadBinary(PyObject* self, PyObject* args) { - const char* name; - int shared; - PyObject *py_bytes; - int device_id; - if(!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &device_id)) { - std::cout << "loadBinary arg parse failed" << std::endl; - return NULL; - } - - // uint8_t* data = (uint8_t*) PyBytes_AsString(py_bytes); - // int data_size = PyBytes_Size(py_bytes); - - if (device_id > devices.size()) { - std::cout << "Device ID not found: " << device_id << std::endl; - return NULL; - } - - ze_device_handle_t device = devices[device_id]; - - int32_t n_regs = 0; - int32_t n_spills = 0; - - ze_module_desc_t module_desc = {}; - module_desc.format = ZE_MODULE_FORMAT_IL_SPIRV; - module_desc.inputSize = PyBytes_Size(py_bytes); - module_desc.pInputModule = (uint8_t*) PyBytes_AsString(py_bytes); - ze_module_handle_t module; - // std::cout << "SPIRV binary size: " << module_desc.inputSize << std::endl; - ZE_CHECK(zeModuleCreate(context, device, &module_desc, &module, nullptr)); - - // std::cout << "loadBinary zeModuleCreated" << std::endl; - ze_kernel_desc_t kernel_desc = {}; - kernel_desc.pKernelName = name; - ze_kernel_handle_t fun; - ZE_CHECK(zeKernelCreate(module, &kernel_desc, &fun)); - - // std::cout << "loadBinary zeKernelCreated" << std::endl; - - if(PyErr_Occurred()) { - std::cout << "loadBinary error occurred" << std::endl; - return NULL; - } - - return Py_BuildValue("(KKii)", (uint64_t)module, (uint64_t)fun, n_regs, n_spills); - } - - bool update(sycl::queue sycl_queue) { - // Get l0-context - auto sycl_context = sycl_queue.get_context(); - ze_context_handle_t hCtxt = get_native(sycl_context); - // Get l0-device - std::vector sycl_devices = sycl_context.get_devices(); - ze_device_handle_t hDev = get_native(sycl_devices[0]); - // Get l0-queue - bool immediate_cmd_list = false; - std::variant queue_var = get_native(sycl_queue); - auto l0_queue = std::get_if(&queue_var); - if (l0_queue == nullptr) { - auto imm_cmd_list = std::get_if(&queue_var); - if (imm_cmd_list == nullptr) { - return false; - } - immediate_cmd_list = true; - sycl_queue_map[sycl_queue].cmd_list = *imm_cmd_list; - } - sycl_queue_map[sycl_queue].context = hCtxt; - sycl_queue_map[sycl_queue].device = hDev; - sycl_queue_map[sycl_queue].queue = immediate_cmd_list ? 0 : *l0_queue; - - // Update global data - context = sycl_queue_map[sycl_queue].context; - uint32_t deviceCount = std::min(sycl_devices.size(), devices.size()); - for (uint32_t i = 0; i < deviceCount; ++i) { - devices[i] = sycl::get_native(sycl_devices[i]); - } - - return true; - } - - static PyObject* initContext(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - context = sycl_queue_map[*sycl_queue].context; - return Py_BuildValue("(K)", (uint64_t)context); - } - - static PyObject* initEventPool(PyObject* self, PyObject* args) { - // Create event pool - ze_event_pool_desc_t tsEventPoolDesc = { - ZE_STRUCTURE_TYPE_EVENT_POOL_DESC, - nullptr, - ZE_EVENT_POOL_FLAG_HOST_VISIBLE, // all events in pool are visible to Host - 1 // count - }; - ZE_CHECK(zeEventPoolCreate(context, &tsEventPoolDesc, 0, nullptr, &eventPoolHandle)); - - return Py_BuildValue("(K)", (uint64_t)eventPoolHandle); - // Py_RETURN_NONE; - } - - static PyObject* initDevices(PyObject* self, PyObject *args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - - auto sycl_context = sycl_queue->get_context(); - - // Get l0-device - std::vector sycl_devices = sycl_context.get_devices(); - - // Retrieve devices - uint32_t deviceCount = sycl_devices.size(); - for (uint32_t i = 0; i < deviceCount; ++i) { - devices.push_back(sycl::get_native(sycl_devices[i])); - } - - // npy_intp dims[1]; - // dims[0] = deviceCount; - // std::cout << "Before PyArray_SimpleNewFromData: " << devices.size() << " " << devices.data()[0] << std::endl; - // PyObject* arr = PyArray_SimpleNewFromData(1, dims, NPY_UINT64, reinterpret_cast(devices.data())); - // std::cout << "After PyArray_SimpleNewFromData: " << devices.data()[0] << std::endl; - // PyObject* ret = Py_BuildValue("(O)", arr); - // std::cout << "After Py_BuildValue" << std::endl; - // return ret; - return Py_BuildValue("(i)", deviceCount); - // Py_RETURN_NONE; - } - - static PyObject* getL0ImmCommandList(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].cmd_list)); - } - static PyObject* getL0Queue(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].queue)); - } - static PyObject* getL0DevPtr(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].device)); - } - static PyObject* getL0CtxtPtr(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].context)); - } - static PyObject* isUsingICL(PyObject* self, PyObject* args) { - void* queue; - if(!PyArg_ParseTuple(args, "K", &queue)) - return NULL; - sycl::queue* sycl_queue = static_cast(queue); - if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { - update(*sycl_queue); - } - uint32_t using_icl = sycl_queue_map[*sycl_queue].cmd_list != 0 ? 1 : 0; - return Py_BuildValue("(i)", using_icl); - } - - static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBinary, METH_VARARGS, "Load provided SPV into ZE driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"}, - {"init_context", initContext, METH_VARARGS, "Initialize the ZE GPU context"}, - {"init_devices", initDevices, METH_VARARGS, "Initialize the ZE GPU devices and return device count"}, - {"init_event_pool", initEventPool, METH_VARARGS, "Initialize ZE event pool"}, - {"get_l0_imm_cmd_list", getL0ImmCommandList, METH_VARARGS, "Get l0 command list in case of immediate command list"}, - {"get_l0_queue", getL0Queue, METH_VARARGS, "Get l0 queue from sycl queue"}, - {"get_l0_dev_ptr", getL0DevPtr, METH_VARARGS, "Extract l0 device pointer from sycl queue"}, - {"get_l0_ctxt_ptr", getL0CtxtPtr, METH_VARARGS, "Extract l0 context pointer from sycl queue"}, - {"is_using_icl", isUsingICL, METH_VARARGS, "Extract sycl queue info, if it is using ICL"}, - {NULL, NULL, 0, NULL} // sentinel - }; - - static struct PyModuleDef ModuleDef = { - PyModuleDef_HEAD_INIT, - "spirv_utils", - NULL, //documentation - -1, //size - ModuleMethods - }; - - PyMODINIT_FUNC PyInit_spirv_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; - } - """ - - def __init__(self): - src = self._generate_src() - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - fname = "spirv_utils.so" - cache_path = cache.get_file(fname) - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.cpp") - with open(src_path, "w") as f: - f.write(src) - so = _build("spirv_utils", src_path, tmpdir) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), fname, binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location("spirv_utils", cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.get_l0_queue = mod.get_l0_queue - self.get_l0_imm_cmd_list = mod.get_l0_imm_cmd_list - self.get_l0_dev_ptr = mod.get_l0_dev_ptr - self.get_l0_ctxt_ptr = mod.get_l0_ctxt_ptr - self.is_using_icl = mod.is_using_icl - self.context = mod.init_context(ipex.xpu.current_stream().sycl_queue) - self.device_count = mod.init_devices(ipex.xpu.current_stream().sycl_queue) - self.event_pool = mod.init_event_pool()[0] - self.current_device = 0 if self.device_count[0] > 0 else -1 - - def get_current_device(instance): - return instance.current_device - - def get_event_pool(instance): - return instance.event_pool - - def set_current_device(instance, idx): - assert instance.device_count[0] > idx, "Device id not found" - instance.current_device = idx - - def get_device_capability(instance, idx): - return (0, 0) - - -class SpirvDriver(DriverBase): - - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(SpirvDriver, cls).__new__(cls) - return cls.instance - - def __init__(self): - self.utils = SpirvUtils() - self.backend = self.SPIRV - self.binary_ext = "spv" - self.get_current_stream = self.get_current_stream - self.get_current_device = self.utils.get_current_device - - def get_current_stream(self, device): - # FIXME - return 0 - - @functools.lru_cache() - def get_current_target(self): - return ("xpu", 0) - - def assemble_tensormap_to_arg(self, tensormaps_info, args): - args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) - return args_ptr - - -class UnsupportedDriver(DriverBase): - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(UnsupportedDriver, cls).__new__(cls) - return cls.instance - - def __init__(self): - self.utils = None - self.backend = None - - -# ----------------------------- -# Driver -# ----------------------------- +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() class LazyProxy: @@ -654,17 +43,4 @@ def __str__(self): return str(self._obj) -def initialize_driver(): - import torch - - if torch.version.hip is not None: - return HIPDriver() - elif torch.cuda.is_available(): - return CudaDriver() - elif torch.xpu.is_available(): - return SpirvDriver() - else: - return UnsupportedDriver() - - -driver = LazyProxy(initialize_driver) +driver = LazyProxy(_create_driver) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d496668c3d..d226caa050 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -1,5 +1,4 @@ from __future__ import annotations, division - import ast import functools import hashlib @@ -12,7 +11,6 @@ import torch import intel_extension_for_pytorch as ipex -from ..common.backend import get_backend, get_cuda_version_key from .interpreter import InterpretedFunction from ..runtime.driver import driver @@ -376,8 +374,7 @@ def _get_arg_data_ptr(self, arg) -> str: return arg def run(self, *args, grid, warmup, **kwargs): - from ..compiler import CompiledKernel, compile, ASTSource - from ..compiler.backends import make_backend + from ..compiler import CompiledKernel, compile, ASTSource, make_backend # deprecated arguments assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" assert "device" not in kwargs, "device option is deprecated; current device will be used" @@ -413,11 +410,7 @@ def run(self, *args, grid, warmup, **kwargs): sig_key = tuple(arg.signature_key() for arg in args if not arg.param.is_constexpr) spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize) constexpr_key = tuple(arg.value for arg in args if arg.param.is_constexpr) - version_key = get_cuda_version_key() - if target[0] in ["xpu"]: - # FIXME - version_key = "" - key = (version_key, sig_key, constexpr_key, spec_key, options) + key = (sig_key, constexpr_key, spec_key, options) # Kernel is not cached; we have to compile. if key not in self.cache[device]: configs = (self._get_config(*[arg.value for arg in args]), ) @@ -648,7 +641,6 @@ class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype self.base = base - self.is_cuda = base.is_cuda self.device = base.device self.shape = self.base.shape diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index 6f00e81925..8f0168d59d 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -268,19 +268,8 @@ def _output_stubs(self) -> str: # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} # return core.extern_elementwise("libdevice", , , , _builder) import_str = "from . import core\n" - import_str += "import os\n" - import_str += "import functools\n" header_str = "" - header_str += "@functools.lru_cache()\n" - header_str += "def libdevice_path():\n" - header_str += " import torch\n" - header_str += " third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\")\n" - header_str += " if torch.version.hip is None:\n" - header_str += " default = os.path.join(third_party_dir, \"cuda\", \"lib\", \"libdevice.10.bc\")\n" - header_str += " else:\n" - header_str += " default = ''\n" - header_str += " return os.getenv(\"TRITON_LIBDEVICE_PATH\", default)\n" func_str = "" for symbols in self._symbol_groups.values(): func_str += "@core.extern\n" diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 51193ba3d8..9bd6fcab20 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -8,7 +8,7 @@ import triton from triton.compiler.code_generator import kernel_suffix -from triton.compiler.make_launcher import ty_to_cpp +from triton.backends.xpu.driver import ty_to_cpp desc = """ Triton ahead-of-time compiler: diff --git a/scripts/compile-triton.sh b/scripts/compile-triton.sh index 2826cec3fc..2872a22beb 100755 --- a/scripts/compile-triton.sh +++ b/scripts/compile-triton.sh @@ -82,7 +82,6 @@ function build_llvm { -DLLVM_ENABLE_ASSERTIONS=true \ -DLLVM_ENABLE_PROJECTS="mlir" \ -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ - -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \ -DLLVM_INSTALL_UTILS=true \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_INSTALL_PREFIX=$PACKAGES_DIR/llvm \ diff --git a/third_party/amd_hip_backend b/third_party/amd_hip_backend deleted file mode 160000 index d0ad70d55d..0000000000 --- a/third_party/amd_hip_backend +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d0ad70d55df3ebe11cc80bbb364a91551e6b6248 diff --git a/third_party/cuda/CMakeLists.txt b/third_party/cuda/CMakeLists.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/cuda/backend/__init__.py b/third_party/cuda/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/cuda/backend/compiler.py b/third_party/cuda/backend/compiler.py new file mode 100644 index 0000000000..d09cae06f2 --- /dev/null +++ b/third_party/cuda/backend/compiler.py @@ -0,0 +1,567 @@ +from triton.backends.compiler import BaseBackend +from triton._C.libtriton import ir, passes, llvm, nvidia +from triton.runtime import driver +from dataclasses import dataclass +import functools +from typing import Any +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path + +# ------------- TMA stuff ----------------# +# +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +def dummy_tensormaps_info(n=2): + ret = [] + for i in range(n): + ret.append(InfoFromBackendForTensorMap(dummy=True)) + return ret + + +def parse_tma_info(infos, ids_of_folded_args): + ret = [] + for info in infos: + e = InfoFromBackendForTensorMap(infos=info) + e.ids_of_folded_args = ids_of_folded_args + ret.append(e) + return ret + + +def get_tma_mapping(tensormaps_info): + ret = {} + if tensormaps_info is not None: + for i, e in enumerate(tensormaps_info): + ret.update(e.get_address_tma_mapping()) + else: + ret = None + return ret + + +def get_ids_of_tensormaps(tensormaps_info): + ret = None + # order is not relevant + if tensormaps_info is not None: + ret = [e.get_id_of_tensormap() for e in tensormaps_info] + return ret + + +# decouple information for tensormap from backend +# please ignore the naming style, xx_yy is compiler.py style, xxYy is to comply with cuda tensormap style +# mixing style is for readability +class InfoFromBackendForTensorMap: + N = 2 + n = 0 + ntma = 0 + + def __init__(self, infos=None, dummy=False): + self.dummy = dummy + self.ids_of_folded_args = () + if not dummy and not isinstance(infos, dict): + self._extract_info_from_backend(infos) + elif not dummy and isinstance(infos, dict): + self._extract_info_from_dict(infos) + elif dummy: + self._dummy() + + def _dummy(self): + assert InfoFromBackendForTensorMap.n < InfoFromBackendForTensorMap.N + if InfoFromBackendForTensorMap.n == 0: + self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"] + self.tensorRank = 4 + self.globalAddressArgIdx = 0 + self.globalStridesArgIdx = [7, 6, -1, -1] + self.globalDimsArgIdx = [5, 3, -1, -1] + self.boxDims = [16, 64, 1, 1] + self.elementStrides = [1, 1, 1, 1] + self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"] + self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"] + self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"] + self.TMADescArgIdx = 11 + self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"] + InfoFromBackendForTensorMap.n += 1 + return + if InfoFromBackendForTensorMap.n == 1: + self.tensorDataType = driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"] + self.tensorRank = 4 + self.globalAddressArgIdx = 1 + self.globalStridesArgIdx = [7, 6, -1, -1] + self.globalDimsArgIdx = [5, 3, -1, -1] + self.boxDims = [16, 64, 1, 1] + self.elementStrides = [1, 1, 1, 1] + self.interleave = driver.utils.CUtensorMapInterleave["CU_TENSOR_MAP_INTERLEAVE_NONE"] + self.swizzle = driver.utils.CUtensorMapSwizzle["CU_TENSOR_MAP_SWIZZLE_32B"] + self.l2Promotion = driver.utils.CUtensorMapL2promotion["CU_TENSOR_MAP_L2_PROMOTION_L2_128B"] + self.TMADescArgIdx = 12 + self.oobFill = driver.utils.CUtensorMapFloatOOBfill["CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"] + InfoFromBackendForTensorMap.n += 1 + return + + def _extract_info_from_backend(self, infos): + self.tensorDataType = infos.tensorDataType + self.tensorRank = infos.tensorRank + self.globalAddressArgIdx = infos.globalAddressArgIdx + self.globalStridesArgIdx = infos.globalStridesArgIdx + self.globalDimsArgIdx = infos.globalDimsArgIdx + self.boxDims = infos.boxDims + self.elementStrides = infos.elementStrides + self.interleave = infos.interleave + self.swizzle = infos.swizzle + self.l2Promotion = infos.l2Promotion + self.oobFill = infos.oobFill + self.TMADescArgIdx = infos.TMADescArgIdx + + # dict could be from cached metadata json + def _extract_info_from_dict(self, infos: dict): + self.tensorDataType = infos['tensorDataType'] + self.tensorRank = infos['tensorRank'] + self.globalAddressArgIdx = infos['globalAddressArgIdx'] + self.globalStridesArgIdx = infos['globalStridesArgIdx'] + self.globalDimsArgIdx = infos['globalDimsArgIdx'] + self.boxDims = infos['boxDims'] + self.elementStrides = infos['elementStrides'] + self.interleave = infos['interleave'] + self.swizzle = infos['swizzle'] + self.l2Promotion = infos['l2Promotion'] + self.oobFill = infos['oobFill'] + self.TMADescArgIdx = infos['TMADescArgIdx'] + + def get_address_tma_mapping(self): + return {self.globalAddressArgIdx: self.TMADescArgIdx + len(self.ids_of_folded_args)} + + def get_id_of_tensormap(self): + return self.TMADescArgIdx + len(self.ids_of_folded_args) + + def getTMADescArgIdx(self): + return self.TMADescArgIdx + + # dtype:cuda.CUtensorMapDataType | int + def bytes_from_type(self, dtype): + return { + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4 + }[dtype] + + def getTensorMapDataType(self): + return self.tensorDataType + + def getInterleave(self): + return self.interleave + + def getSwizzle(self): + return self.swizzle + + def getL2Promotion(self): + return self.l2Promotion + + def getOobFill(self): + return self.oobFill + + def getTensorRank(self): + return self.tensorRank + + def getBoxDims(self): + return self.boxDims + + def getElementStrides(self): + return self.elementStrides + + def getGlobalAddress(self, args): + idx = self.getOriginArgIdx(self.globalAddressArgIdx, args) + return args[idx] + + # args, captured kernel args in runtime + def getGlobalDims(self, args): + shape = [] + for e in self.globalDimsArgIdx: + t = 1 + # < 0 means folded arg or constant (-1 - value) + # -1 means extended dim which is 1, -2 means folded arg with constant 1 (-1 - value) + if e == -1: + t = 1 + elif e < 0 and e != -1: + t = -e - 1 + else: + idx = self.getOriginArgIdx(e, args) + t = args[idx] + shape.append(t) + return shape + + def getGlobalStrides(self, args): + t_globalDims = [int(e) for e in self.getGlobalDims(args)] + t_globalStridesArgIdx = self.globalStridesArgIdx.copy() + strides_in_elements = [] + # todo: get all stride from backend even in extended mode + for i in range(self.tensorRank): + t = 1 + if t_globalStridesArgIdx[i] == -1: + for ii in range(i): + t *= t_globalDims[ii] + # -2 means the sride in arguments is folded constant 1, we don't use 1 because it can not be distinguished from index 1 + elif t_globalStridesArgIdx[i] < 0: + t = -1 - t_globalStridesArgIdx[i] + else: + new_idx = self.getOriginArgIdx(t_globalStridesArgIdx[i], args) + t = args[new_idx] + + strides_in_elements.append(t) + + strides_in_elements = strides_in_elements[1:] + strides_in_bytes = [e * self.bytes_from_type(self.tensorDataType) for e in strides_in_elements] + return strides_in_bytes + + def getOriginArgIdx(self, idx, args): + if self.ids_of_folded_args: + ids_before_folding_arg = [i for i in range(len(args)) if i not in self.ids_of_folded_args] + return ids_before_folding_arg[idx] + else: + return idx + + def tensormap(self, args): + return driver.utils.cuTensorMapEncodeTiled( + self.getTensorMapDataType(), + self.getTensorRank(), + self.getGlobalAddress(args), + self.getGlobalDims(args), + self.getGlobalStrides(args), + self.getBoxDims(), + self.getElementStrides(), + self.getInterleave(), + self.getSwizzle(), + self.getL2Promotion(), + self.getOobFill(), + ) + + # make hashable to use as partial key in cache + def __hash__(self): + return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), + tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims), + tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, + self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, + self.l2Promotion, + self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, + other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, + other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, + other.oobFill) + + +# ---------------------------------------------------- + + +def _path_to_binary(binary: str): + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(os.path.dirname(__file__), "bin", binary), + ] + + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. + ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher") + + +@dataclass(frozen=True) +class CUDAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + + def __post_init__(self): + default_libdir = Path(__file__).parent / 'lib' + extern_libs = dict() if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = str(default_libdir / 'libdevice.10.bc') + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +class CUDABackend(BaseBackend): + + @staticmethod + def supports_target(target: tuple): + return target[0] == 'cuda' + + def __init__(self, target: tuple) -> None: + super().__init__(target) + self.capability = target[1] + assert isinstance(self.capability, int) + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} + args["allow_fp8e4nv"] = self.capability >= 89 + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return CUDAOptions(**args) + + def load_dialects(self, ctx): + nvidia.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + cluster_info = nvidia.ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + nvidia.passes.ttgpuir.add_rewrite_tensor_pointer(pm, capability) + nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.ttgpuir.add_accelerate_matmul(pm, capability) + passes.ttgpuir.add_remove_layout_conversions(pm) + if opt.optimize_epilogue: + passes.ttgpuir.add_optimize_epilogue(pm) + passes.ttgpuir.add_optimize_dot_operands(pm) + passes.common.add_cse(pm) + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. + ws_enabled = False + if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: + nvidia.passes.ttnvgpuir.add_wsfeasibility_checking(pm, capability) + pm.run(mod) + ws_enabled = nvidia.passes.ttnvgpuir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + metadata["ws_enabled"] = ws_enabled + if ws_enabled: + nvidia.passes.ttnvgpuir.add_wsdecomposing(pm, capability) + nvidia.passes.ttnvgpuir.add_wspipeline(pm, opt.num_stages, opt.num_warps, capability) + nvidia.passes.ttnvgpuir.add_wsmutex(pm, capability) + nvidia.passes.ttnvgpuir.add_wsmaterialization(pm, capability) + passes.common.add_licm(pm) + passes.common.add_cse(pm) + else: + passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) + nvidia.passes.ttnvgpuir.add_materialize_load_store(pm, opt.num_warps, capability) + if capability // 10 <= 8: + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_decompose_conversions(pm) + nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if capability // 10 >= 9: + nvidia.passes.ttnvgpuir.add_fence_insertion(pm) + nvidia.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + return mod + + @staticmethod + def make_llir(src, metadata, options, capability): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + mod = src + # TritonGPU -> LLVM-IR (MLIR) + tma_infos = nvidia.TMAInfos() + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, tma_infos) + if metadata["ws_enabled"]: + passes.common.add_licm(pm) + passes.common.add_cse(pm) + nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + nvidia.set_nvvm_reflect_ftz(llvm_mod) + if options.extern_libs: + for name, path in options.extern_libs: + llvm.link_extern_lib(llvm_mod, path) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + # Get some metadata + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] + metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_ptx(src, metadata, opt, capability): + proc = 'sm_90a' if capability == 90 else f'sm_{capability}' + ret = llvm.translate_to_asm(src, 'nvptx64-nvidia-cuda', proc, '', ['nvptx-short-ptr'], opt.enable_fp_fusion, + False) + # Find kernel names (there should only be one) + names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) + assert len(names) == 1 + metadata["name"] = names[0] + # post-process + ptx_version = opt.ptx_version + if ptx_version is None: + _, cuda_version = _path_to_binary("ptxas") + ptx_version = ptx_get_version(cuda_version) + ptx_version = f'{ptx_version//10}.{ptx_version%10}' + ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE) + # Remove the debug flag that prevents ptxas from optimizing the code + ret = re.sub(r",\s*debug|debug,\s*", "", ret) + return ret + + @staticmethod + def make_cubin(src, metadata, opt, capability): + ptxas, _ = _path_to_binary("ptxas") + with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \ + tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog: + fsrc.write(src) + fsrc.flush() + fbin = fsrc.name + '.o' + + line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' + fmad = '' if opt.enable_fp_fusion else ' --fmad=false' + suffix = 'a ' if capability == 90 else ' ' + cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' + + try: + subprocess.run(cmd, shell=True, check=True) + except subprocess.CalledProcessError as e: + with open(flog.name) as log_file: + log = log_file.read() + if e.returncode == 255: + raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') + elif e.returncode == 128 + signal.SIGSEGV: + raise RuntimeError( + f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') + else: + raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') + finally: + if os.path.exists(fsrc.name): + os.remove(fsrc.name) + if os.path.exists(flog.name): + os.remove(flog.name) + + with open(fbin, 'rb') as f: + cubin = f.read() + if os.path.exists(fbin): + os.remove(fbin) + return cubin + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability) + + @functools.lru_cache() + def hash(self): + version = subprocess.check_output([_path_to_binary("ptxas")[0], "--version"]) + return f'{version}-{self.capability}' diff --git a/python/triton/runtime/backends/cuda.c b/third_party/cuda/backend/driver.c similarity index 100% rename from python/triton/runtime/backends/cuda.c rename to third_party/cuda/backend/driver.c diff --git a/third_party/cuda/backend/driver.py b/third_party/cuda/backend/driver.py new file mode 100644 index 0000000000..243a99d8f0 --- /dev/null +++ b/third_party/cuda/backend/driver.py @@ -0,0 +1,402 @@ +import os +import hashlib +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] +library_dir = [os.path.join(dirname, "lib")] +libraries = ['cuda'] + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class CudaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CudaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + self.CUtensorMapDataType = mod.CUtensorMapDataType + self.CUtensorMapInterleave = mod.CUtensorMapInterleave + self.CUtensorMapSwizzle = mod.CUtensorMapSwizzle + self.CUtensorMapL2promotion = mod.CUtensorMapL2promotion + self.CUtensorMapFloatOOBfill = mod.CUtensorMapFloatOOBfill + self.cuTensorMapEncodeTiled = mod.cuTensorMapEncodeTiled + self.cuMemAlloc = mod.cuMemAlloc + self.cuMemcpyHtoD = mod.cuMemcpyHtoD + self.cuMemFree = mod.cuMemFree + self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def generate_cu_signature(constants, signature, ids): + # CUtensorMap*s are always the last arguments + num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0 + if ids["ids_of_tensormaps"] is not None: + for i, _ in enumerate(ids["ids_of_tensormaps"]): + signature[num_regular_signatures + i] = '*CUtensorMap' + return signature, num_regular_signatures + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + signature, desc_start_idx = generate_cu_signature(constants, signature, ids) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] + params = [ + i for i in signature.keys() + if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs) + ] + src = f""" +#include \"cuda.h\" +#include +#include +#include + +static inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); + +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + void* handle = dlopen("libcuda.so", RTLD_LAZY); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return NULL; + }} + // Clear any existing error + dlerror(); + cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + if (num_ctas == 1) {{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else {{ + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 2; + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + CUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int num_warps; + int num_ctas; + int clusterDimX; + int clusterDimY; + int clusterDimZ; + int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel{', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class CudaLauncher(object): + + def __init__(self, src, metadata): + ids = { + "ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": + metadata.get("ids_of_folded_args", + tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() + } + constants = src.constants if hasattr(src, "constants") else dict() + enable_warp_specialization = False + src = make_launcher(constants, src.signature, ids) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +# ------------------------ +# Tensor Map +# ------------------------ + + +class TensorMapManager: + + def __init__(self, utils): + self.tensormaps_device = {} + self.utils = utils + + def __getitem__(self, key: tuple): + if key in self.tensormaps_device: + return int(self.tensormaps_device[key]) + else: + (e, args) = key + t_tensormap = e.tensormap(args) + TENSORMAP_SIZE_IN_BYTES = 128 + t_tensormap_device = self.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) + self.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) + self.tensormaps_device[key] = t_tensormap_device + return int(self.tensormaps_device[key]) + + def __del__(self): + for _, v in self.tensormaps_device.items(): + self.utils.cuMemFree(v) + + +class CudaDriver(GPUDriver): + + def __init__(self): + self.utils = CudaUtils() # TODO: make static + self.tensormap_manager = TensorMapManager(self.utils) # TODO: make static + self.binary_ext = "cubin" + self.launcher_cls = CudaLauncher + super().__init__() + + def get_current_target(self): + device = self.get_current_device() + capability = self.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + return ("cuda", capability) + + @staticmethod + def is_active(): + import torch + return torch.version.hip is None and not torch.xpu.is_available() + + def assemble_tensormap_to_arg(self, tensormaps_info, args): + args_with_tma = list(args) + if tensormaps_info is not None: + # tuple for hashable + args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) + for i, e in enumerate(tensormaps_info): + args_with_tma.append(self.tensormap_manager[(e, args_ptr)]) + return args_with_tma diff --git a/python/triton/third_party/cuda/include/cuda.h b/third_party/cuda/backend/include/cuda.h similarity index 100% rename from python/triton/third_party/cuda/include/cuda.h rename to third_party/cuda/backend/include/cuda.h diff --git a/python/triton/third_party/cuda/lib/libdevice.10.bc b/third_party/cuda/backend/lib/libdevice.10.bc similarity index 100% rename from python/triton/third_party/cuda/lib/libdevice.10.bc rename to third_party/cuda/backend/lib/libdevice.10.bc diff --git a/python/src/nvidia.cc b/third_party/cuda/triton_nvidia.cc similarity index 92% rename from python/src/nvidia.cc rename to third_party/cuda/triton_nvidia.cc index 100b2118aa..c1f444dd93 100644 --- a/python/src/nvidia.cc +++ b/third_party/cuda/triton_nvidia.cc @@ -1,6 +1,5 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR/Dialect/GENX/GENXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "passes.h" #include "triton/Conversion/NVGPUToLLVM/Passes.h" @@ -26,7 +25,7 @@ void init_triton_nvidia_passes_ttgpuir(py::module &&m) { // nvidia-specificontext m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability, mlir::triton::gpu::TMAMetadataTy *tmaMetadata) { - pm.addPass(createConvertTritonGPUToLLVMPass(capability, mlir::triton::GENX, + pm.addPass(createConvertTritonGPUToLLVMPass(capability, mlir::triton::NVVM, tmaMetadata)); }); } @@ -60,7 +59,7 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) { }); } -void init_triton_nvidia(py::module &&m) { +void init_triton_nvidia(py::module &&m){ auto passes = m.def_submodule("passes"); init_triton_nvidia_passes_ttgpuir(passes.def_submodule("ttgpuir")); init_triton_nvidia_passes_ttnvgpuir(passes.def_submodule("ttnvgpuir")); @@ -110,22 +109,10 @@ void init_triton_nvidia(py::module &&m) { registry.insert(); mlir::registerNVVMDialectTranslation(registry); - mlir::registerGENXDialectTranslation(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); - // init llvm - m.def("init_llvm", []() { - static std::once_flag init_flag; - std::call_once(init_flag, []() { - LLVMInitializeNVPTXTargetInfo(); - LLVMInitializeNVPTXTarget(); - LLVMInitializeNVPTXTargetMC(); - LLVMInitializeNVPTXAsmPrinter(); - }); - }); - // TODO: could be done in python if we had a generic interface to set metadata m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) { // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters diff --git a/third_party/intel_xpu_backend b/third_party/intel_xpu_backend deleted file mode 160000 index d05dc79dad..0000000000 --- a/third_party/intel_xpu_backend +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d05dc79dad638b8ebbacfef44886f568b5885fc3 diff --git a/third_party/triton_shared b/third_party/triton_shared deleted file mode 160000 index 1759426a09..0000000000 --- a/third_party/triton_shared +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1759426a098c6a2e5621311613832d414421a43d diff --git a/third_party/xpu/CMakeLists.txt b/third_party/xpu/CMakeLists.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/xpu/backend/__init__.py b/third_party/xpu/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/xpu/backend/bin/spirv-dis b/third_party/xpu/backend/bin/spirv-dis new file mode 100755 index 0000000000..50456c5321 Binary files /dev/null and b/third_party/xpu/backend/bin/spirv-dis differ diff --git a/python/triton/compiler/utils.py b/third_party/xpu/backend/compiler.py similarity index 53% rename from python/triton/compiler/utils.py rename to third_party/xpu/backend/compiler.py index 48233afedd..ac99a225f5 100644 --- a/python/triton/compiler/utils.py +++ b/third_party/xpu/backend/compiler.py @@ -1,3 +1,19 @@ +from triton.backends.compiler import BaseBackend +from triton._C.libtriton import ir, passes, llvm, xpu +from triton.runtime import driver +from dataclasses import dataclass +import functools +from typing import Any +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path + +# ------------- TMA stuff ----------------# +# # Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining @@ -19,19 +35,6 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -from __future__ import annotations - -from ..runtime import driver - - -def generate_cu_signature(constants, signature, ids): - # CUtensorMap*s are always the last arguments - num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0 - if ids["ids_of_tensormaps"] is not None: - for i, _ in enumerate(ids["ids_of_tensormaps"]): - signature[num_regular_signatures + i] = '*CUtensorMap' - return signature, num_regular_signatures - def dummy_tensormaps_info(n=2): ret = [] @@ -280,3 +283,232 @@ def __eq__(self, other): other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) + + +# ---------------------------------------------------- + + +def _path_to_binary(binary: str): + paths = [ + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), + os.path.join(os.path.dirname(__file__), "bin", binary), + ] + + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*SPIRV-Tools v(\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. + ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher") + + +@dataclass(frozen=True) +class XPUOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 2 + cluster_dims: tuple = (1, 1, 1) + ptx_version: int = None + enable_warp_specialization: bool = False + enable_persistent: bool = False + optimize_epilogue: bool = False + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + + def __post_init__(self): + default_libdir = Path(__file__).parent / 'lib' + extern_libs = dict() if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = str(default_libdir / 'libsycl-spir64-unknown-unknown.bc') + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()]) + return hashlib.md5(key.encode("utf-8")).hexdigest() + + +class XPUBackend(BaseBackend): + + @staticmethod + def supports_target(target: tuple): + return target[0] == 'xpu' + + def __init__(self, target: tuple) -> None: + super().__init__(target) + self.capability = target[1] + assert isinstance(self.capability, int) + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts} + args["allow_fp8e4nv"] = True + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return XPUOptions(**args) + + def load_dialects(self, ctx): + xpu.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + cluster_info = xpu.ClusterInfo() + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + xpu.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + xpu.passes.ttgpuir.add_rewrite_tensor_pointer(pm, capability) + xpu.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.ttgpuir.add_accelerate_matmul(pm, capability) + passes.ttgpuir.add_remove_layout_conversions(pm) + if opt.optimize_epilogue: + passes.ttgpuir.add_optimize_epilogue(pm) + passes.ttgpuir.add_optimize_dot_operands(pm) + passes.common.add_cse(pm) + # `num_warps` does not mean the total number of warps of a CTA when + # warp specialization is enabled. + # it's the responsibility of the compiler to figure out the exact + # `num_warps` to use. + # TODO: support the case where `num_warps` from user is not 4. + ws_enabled = False + if capability // 10 >= 9 and opt.enable_warp_specialization and opt.num_warps == 4: + xpu.passes.ttnvgpuir.add_wsfeasibility_checking(pm, capability) + pm.run(mod) + ws_enabled = xpu.passes.ttnvgpuir.is_ws_supported(mod) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + metadata["ws_enabled"] = ws_enabled + if ws_enabled: + xpu.passes.ttnvgpuir.add_wsdecomposing(pm, capability) + xpu.passes.ttnvgpuir.add_wspipeline(pm, opt.num_stages, opt.num_warps, capability) + xpu.passes.ttnvgpuir.add_wsmutex(pm, capability) + xpu.passes.ttnvgpuir.add_wsmaterialization(pm, capability) + passes.common.add_licm(pm) + passes.common.add_cse(pm) + else: + passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) + xpu.passes.ttnvgpuir.add_materialize_load_store(pm, opt.num_warps, capability) + if capability // 10 <= 8: + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_decompose_conversions(pm) + xpu.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if capability // 10 >= 9: + xpu.passes.ttnvgpuir.add_fence_insertion(pm) + xpu.passes.ttnvgpuir.add_wsfixup_missing_attrs(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + return mod + + @staticmethod + def make_llir(src, metadata, options, capability): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + mod = src + # TritonGPU -> LLVM-IR (MLIR) + tma_infos = xpu.TMAInfos() + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + xpu.passes.ttgpuir.add_to_llvmir(pm, capability, tma_infos) + if metadata["ws_enabled"]: + passes.common.add_licm(pm) + passes.common.add_cse(pm) + xpu.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + llvm.set_spv_target_triple(llvm_mod) + if options.extern_libs: + for name, path in options.extern_libs: + llvm.link_extern_lib(llvm_mod, path) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + # Get some metadata + if len(tma_infos) > 0: + metadata["tensormaps_info"] = parse_tma_info(tma_infos, metadata["ids_of_folded_args"]) + for i, _ in enumerate(metadata["tensormaps_info"]): + metadata["tensormaps_info"][i].ids_of_folded_args = metadata["ids_of_folded_args"] + metadata["ids_of_tensormaps"] = get_ids_of_tensormaps(metadata.get("tensormaps_info", None)) + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_spv(src, metadata): + ret, name = llvm.translate_to_spirv(src) + metadata["name"] = name + return ret + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["spv"] = lambda src, metadata: self.make_spv(src, metadata) + + @functools.lru_cache() + def hash(self): + version = subprocess.check_output([_path_to_binary("spirv-dis")[0], "--version"]) + return f'{version}-{self.capability}' diff --git a/third_party/xpu/backend/driver.py b/third_party/xpu/backend/driver.py new file mode 100644 index 0000000000..3ad923e9e3 --- /dev/null +++ b/third_party/xpu/backend/driver.py @@ -0,0 +1,752 @@ +import os +import hashlib +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase + + +import intel_extension_for_pytorch as ipex + + +dirname = os.getenv("ZE_PATH", default="/usr/local") +include_dir = [os.path.join(dirname, "include/level_zero")] +library_dir = [os.path.join(dirname, "lib")] +libraries = ['ze_loader'] + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class XPUUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(XPUUtils, cls).__new__(cls) + return cls.instance + + @staticmethod + def _generate_src(): + return """ + #include + #include + #include + #include + #include + #include + #include + #include + + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + #include + #include + + typedef struct l0_resc_handles { + ze_context_handle_t context; + ze_device_handle_t device; + ze_command_queue_handle_t queue; + ze_command_list_handle_t cmd_list; + }l0_resc_handles; + + std::unordered_map sycl_queue_map; + static ze_context_handle_t context = {nullptr}; + static ze_driver_handle_t driverHandle = {nullptr}; + static ze_event_pool_handle_t eventPoolHandle = {nullptr}; + + static std::vector devices; + + static inline void gpuAssert(ze_result_t code, const char *file, int line) + { + if (code != ZE_RESULT_SUCCESS) + { + const char* prefix = "Triton Error [ZE]: "; + std::string str = std::to_string(code); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str.c_str()); + PyErr_SetString(PyExc_RuntimeError, err); + } + } + + #define ZE_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); if(PyErr_Occurred()) return NULL; } + + static PyObject* getDeviceProperties(PyObject* self, PyObject* args){ + int device_id; + if(!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + if (device_id > devices.size()) { + std::cout << "Device ID not found: " << device_id << std::endl; + return NULL; + } + + // Get device handle + ze_device_handle_t phDevice = devices[device_id]; + + // create a struct to hold device properties + ze_device_properties_t device_properties = {}; + device_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + zeDeviceGetProperties(phDevice, &device_properties); + + int multiprocessor_count = device_properties.numSlices * device_properties.numSubslicesPerSlice; + int sm_clock_rate = device_properties.coreClockRate; + + ze_device_compute_properties_t compute_properties = {}; + compute_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_COMPUTE_PROPERTIES; + zeDeviceGetComputeProperties(phDevice, &compute_properties); + int max_shared_mem = compute_properties.maxSharedLocalMemory; + + uint32_t memoryCount = 0; + zeDeviceGetMemoryProperties(phDevice, &memoryCount, nullptr); + auto pMemoryProperties = new ze_device_memory_properties_t[memoryCount]; + for( uint32_t mem = 0; mem < memoryCount; ++mem ) + { + pMemoryProperties[mem].stype = ZE_STRUCTURE_TYPE_DEVICE_MEMORY_PROPERTIES; + pMemoryProperties[mem].pNext = nullptr; + } + zeDeviceGetMemoryProperties(phDevice, &memoryCount, pMemoryProperties); + // for( uint32_t mem = 0; mem < memoryCount; ++mem ) + // { + // std::cout << to_string( pMemoryProperties[ mem ] ) << std::endl; + // } + + int mem_clock_rate = pMemoryProperties[0].maxClockRate; + int mem_bus_width = pMemoryProperties[0].maxBusWidth; + + delete[] pMemoryProperties; + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", max_shared_mem, + "multiprocessor_count", multiprocessor_count, + "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, + "mem_bus_width", mem_bus_width); + } + + static PyObject* loadBinary(PyObject* self, PyObject* args) { + const char* name; + int shared; + PyObject *py_bytes; + int device_id; + if(!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &device_id)) { + std::cout << "loadBinary arg parse failed" << std::endl; + return NULL; + } + + // uint8_t* data = (uint8_t*) PyBytes_AsString(py_bytes); + // int data_size = PyBytes_Size(py_bytes); + + if (device_id > devices.size()) { + std::cout << "Device ID not found: " << device_id << std::endl; + return NULL; + } + + ze_device_handle_t device = devices[device_id]; + + int32_t n_regs = 0; + int32_t n_spills = 0; + + ze_module_desc_t module_desc = {}; + module_desc.format = ZE_MODULE_FORMAT_IL_SPIRV; + module_desc.inputSize = PyBytes_Size(py_bytes); + module_desc.pInputModule = (uint8_t*) PyBytes_AsString(py_bytes); + ze_module_handle_t module; + // std::cout << "SPIRV binary size: " << module_desc.inputSize << std::endl; + ZE_CHECK(zeModuleCreate(context, device, &module_desc, &module, nullptr)); + + // std::cout << "loadBinary zeModuleCreated" << std::endl; + ze_kernel_desc_t kernel_desc = {}; + kernel_desc.pKernelName = name; + ze_kernel_handle_t fun; + ZE_CHECK(zeKernelCreate(module, &kernel_desc, &fun)); + + // std::cout << "loadBinary zeKernelCreated" << std::endl; + + if(PyErr_Occurred()) { + std::cout << "loadBinary error occurred" << std::endl; + return NULL; + } + + return Py_BuildValue("(KKii)", (uint64_t)module, (uint64_t)fun, n_regs, n_spills); + } + + bool update(sycl::queue sycl_queue) { + // Get l0-context + auto sycl_context = sycl_queue.get_context(); + ze_context_handle_t hCtxt = get_native(sycl_context); + // Get l0-device + std::vector sycl_devices = sycl_context.get_devices(); + ze_device_handle_t hDev = get_native(sycl_devices[0]); + // Get l0-queue + bool immediate_cmd_list = false; + std::variant queue_var = get_native(sycl_queue); + auto l0_queue = std::get_if(&queue_var); + if (l0_queue == nullptr) { + auto imm_cmd_list = std::get_if(&queue_var); + if (imm_cmd_list == nullptr) { + return false; + } + immediate_cmd_list = true; + sycl_queue_map[sycl_queue].cmd_list = *imm_cmd_list; + } + sycl_queue_map[sycl_queue].context = hCtxt; + sycl_queue_map[sycl_queue].device = hDev; + sycl_queue_map[sycl_queue].queue = immediate_cmd_list ? 0 : *l0_queue; + + // Update global data + context = sycl_queue_map[sycl_queue].context; + uint32_t deviceCount = std::min(sycl_devices.size(), devices.size()); + for (uint32_t i = 0; i < deviceCount; ++i) { + devices[i] = sycl::get_native(sycl_devices[i]); + } + + return true; + } + + static PyObject* initContext(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + context = sycl_queue_map[*sycl_queue].context; + return Py_BuildValue("(K)", (uint64_t)context); + } + + static PyObject* initEventPool(PyObject* self, PyObject* args) { + // Create event pool + ze_event_pool_desc_t tsEventPoolDesc = { + ZE_STRUCTURE_TYPE_EVENT_POOL_DESC, + nullptr, + ZE_EVENT_POOL_FLAG_HOST_VISIBLE, // all events in pool are visible to Host + 1 // count + }; + ZE_CHECK(zeEventPoolCreate(context, &tsEventPoolDesc, 0, nullptr, &eventPoolHandle)); + + return Py_BuildValue("(K)", (uint64_t)eventPoolHandle); + // Py_RETURN_NONE; + } + + static PyObject* initDevices(PyObject* self, PyObject *args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + + auto sycl_context = sycl_queue->get_context(); + + // Get l0-device + std::vector sycl_devices = sycl_context.get_devices(); + + // Retrieve devices + uint32_t deviceCount = sycl_devices.size(); + for (uint32_t i = 0; i < deviceCount; ++i) { + devices.push_back(sycl::get_native(sycl_devices[i])); + } + + // npy_intp dims[1]; + // dims[0] = deviceCount; + // std::cout << "Before PyArray_SimpleNewFromData: " << devices.size() << " " << devices.data()[0] << std::endl; + // PyObject* arr = PyArray_SimpleNewFromData(1, dims, NPY_UINT64, reinterpret_cast(devices.data())); + // std::cout << "After PyArray_SimpleNewFromData: " << devices.data()[0] << std::endl; + // PyObject* ret = Py_BuildValue("(O)", arr); + // std::cout << "After Py_BuildValue" << std::endl; + // return ret; + return Py_BuildValue("(i)", deviceCount); + // Py_RETURN_NONE; + } + + static PyObject* getL0ImmCommandList(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].cmd_list)); + } + static PyObject* getL0Queue(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].queue)); + } + static PyObject* getL0DevPtr(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].device)); + } + static PyObject* getL0CtxtPtr(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + return Py_BuildValue("(K)", (uint64_t)(sycl_queue_map[*sycl_queue].context)); + } + static PyObject* isUsingICL(PyObject* self, PyObject* args) { + void* queue; + if(!PyArg_ParseTuple(args, "K", &queue)) + return NULL; + sycl::queue* sycl_queue = static_cast(queue); + if(sycl_queue_map.find(*sycl_queue) == sycl_queue_map.end()) { + update(*sycl_queue); + } + uint32_t using_icl = sycl_queue_map[*sycl_queue].cmd_list != 0 ? 1 : 0; + return Py_BuildValue("(i)", using_icl); + } + + static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, "Load provided SPV into ZE driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, "Get the properties for a given device"}, + {"init_context", initContext, METH_VARARGS, "Initialize the ZE GPU context"}, + {"init_devices", initDevices, METH_VARARGS, "Initialize the ZE GPU devices and return device count"}, + {"init_event_pool", initEventPool, METH_VARARGS, "Initialize ZE event pool"}, + {"get_l0_imm_cmd_list", getL0ImmCommandList, METH_VARARGS, "Get l0 command list in case of immediate command list"}, + {"get_l0_queue", getL0Queue, METH_VARARGS, "Get l0 queue from sycl queue"}, + {"get_l0_dev_ptr", getL0DevPtr, METH_VARARGS, "Extract l0 device pointer from sycl queue"}, + {"get_l0_ctxt_ptr", getL0CtxtPtr, METH_VARARGS, "Extract l0 context pointer from sycl queue"}, + {"is_using_icl", isUsingICL, METH_VARARGS, "Extract sycl queue info, if it is using ICL"}, + {NULL, NULL, 0, NULL} // sentinel + }; + + static struct PyModuleDef ModuleDef = { + PyModuleDef_HEAD_INIT, + "spirv_utils", + NULL, //documentation + -1, //size + ModuleMethods + }; + + PyMODINIT_FUNC PyInit_spirv_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; + } + """ + + def __init__(self): + mod = compile_module_from_src(self._generate_src(), "spirv_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + self.get_l0_queue = mod.get_l0_queue + self.get_l0_imm_cmd_list = mod.get_l0_imm_cmd_list + self.get_l0_dev_ptr = mod.get_l0_dev_ptr + self.get_l0_ctxt_ptr = mod.get_l0_ctxt_ptr + self.is_using_icl = mod.is_using_icl + self.context = mod.init_context(ipex.xpu.current_stream().sycl_queue) + self.device_count = mod.init_devices(ipex.xpu.current_stream().sycl_queue) + self.event_pool = mod.init_event_pool()[0] + self.current_device = 0 if self.device_count[0] > 0 else -1 + + def get_current_device(instance): + return instance.current_device + + def get_event_pool(instance): + return instance.event_pool + + def set_current_device(instance, idx): + assert instance.device_count[0] > idx, "Device id not found" + instance.current_device = idx + + def get_device_capability(instance, idx): + return (0, 0) + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def generate_cu_signature(constants, signature, ids): + # CUtensorMap*s are always the last arguments + num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0 + if ids["ids_of_tensormaps"] is not None: + for i, _ in enumerate(ids["ids_of_tensormaps"]): + signature[num_regular_signatures + i] = '*CUtensorMap' + return signature, num_regular_signatures + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + signature, desc_start_idx = generate_cu_signature(constants, signature, ids) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "void*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "void*": "K", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiiiiiiiiKKKKKOOOK" + ''.join( + [format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + src = f""" + #include + #include + #include + #include + #include + + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + #include + #include + #include + + static inline void gpuAssert(ze_result_t code, const char *file, int line) + {{ + if (code != ZE_RESULT_SUCCESS) + {{ + const char* prefix = "Triton Error [ZE]: "; + std::string str = std::to_string(code); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str.c_str()); + PyErr_SetString(PyExc_RuntimeError, err); + }} + }} + + #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + + static void _regular_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int shared_memory, + ze_command_queue_handle_t queue, ze_device_handle_t _dev, ze_context_handle_t _ctxt, + ze_kernel_handle_t function, ze_event_pool_handle_t event_pool + {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + + if (gridX*gridY*gridZ > 0) {{ + {" ".join(f'zeKernelSetArgumentValue(function, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} + if (shared_memory) {{ + uint32_t num_params = sizeof(params)/sizeof(params[0]); + zeKernelSetArgumentValue(function, num_params, shared_memory, NULL); + }} + zeKernelSetGroupSize(function, 32*num_warps, 1, 1); + + ze_group_count_t grpCount = {{gridX, gridY, gridZ}}; + + // Create command list + ze_command_list_handle_t CmdList; + ze_command_list_desc_t CommandListDesc_ = {{ + ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC, + nullptr, + 0, + 0, + }}; + + ZE_CHECK(zeCommandListCreate(_ctxt, _dev, &CommandListDesc_, &CmdList)); + + ze_event_desc_t eventDesc = {{ + ZE_STRUCTURE_TYPE_EVENT_DESC, + nullptr, + 0, + 0, + ZE_EVENT_SCOPE_FLAG_HOST + }}; + ze_event_handle_t hEvent; + ZE_CHECK(zeEventCreate(event_pool, &eventDesc, &hEvent)); + + // Append a signal of an event into the command list after the kernel executes + ZE_CHECK(zeCommandListAppendLaunchKernel(CmdList, function, &grpCount, hEvent, 0, nullptr)); + + // close command list + ZE_CHECK(zeCommandListClose(CmdList)); + + // FIXME: The following statement currently doesn't synchronize all IPEX SYCL queues. + // Needs to find all IPEX SYCL queues + // Synchronize the command queue to ensure previous IPEX SYCL commands complete before Triton kernel starts + // ZE_CHECK(zeCommandQueueSynchronize(queue, std::numeric_limits::max())); + + // execute command list + ZE_CHECK(zeCommandQueueExecuteCommandLists(queue, 1, &CmdList, nullptr)); + + // Wait on event to complete + ZE_CHECK(zeEventHostSynchronize(hEvent, std::numeric_limits::max())); + }} + }} + + static void _launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int shared_memory, + ze_command_list_handle_t queue, ze_kernel_handle_t function, ze_event_pool_handle_t event_pool + {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + + if (gridX*gridY*gridZ > 0) {{ + {" ".join(f'zeKernelSetArgumentValue(function, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} + if (shared_memory) {{ + uint32_t num_params = sizeof(params)/sizeof(params[0]); + zeKernelSetArgumentValue(function, num_params, shared_memory, NULL); + }} + zeKernelSetGroupSize(function, 32*num_warps, 1, 1); + ze_group_count_t grpCount = {{gridX, gridY, gridZ}}; + + ze_event_desc_t eventDesc = {{ + ZE_STRUCTURE_TYPE_EVENT_DESC, + nullptr, + 0, + 0, + ZE_EVENT_SCOPE_FLAG_HOST + }}; + ze_event_handle_t hEvent; + ZE_CHECK(zeEventCreate(event_pool, &eventDesc, &hEvent)); + + // FIXME: The following statement currently doesn't synchronize all IPEX SYCL queues. + // Needs to find all IPEX SYCL queues + // Synchronize to ensure previous IPEX SYCL commands complete before Triton kernel starts + ZE_CHECK(zeCommandListHostSynchronize(queue, std::numeric_limits::max())); + + // Append a signal of an event into the command list after the kernel executes + ZE_CHECK(zeCommandListAppendLaunchKernel(queue, function, &grpCount, hEvent, 0, nullptr)); + // Wait on event to complete + ZE_CHECK(zeEventHostSynchronize(hEvent, std::numeric_limits::max())); + }} + }} + + typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; + }} DevicePtrInfo; + + static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + PyTypeObject* obj_type = Py_TYPE(obj); + + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; + }} + + static PyObject* launch(PyObject* self, PyObject* args) {{ + + int gridX, gridY, gridZ; + uint64_t _queue; + uint64_t _stream; + uint64_t _function; + uint64_t _event_pool; + uint64_t _dev; + uint64_t _ctxt; + int num_warps; + int num_ctas; + int clusterDimX; + int clusterDimY; + int clusterDimZ; + int _is_icl; + int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + + + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &num_ctas, + &clusterDimX, &clusterDimY, &clusterDimZ, &shared_memory, &_is_icl, &_stream, + &_queue, &_dev, &_ctxt, &_function, &launch_enter_hook, &launch_exit_hook, + &compiled_kernel, &_event_pool + {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None) {{ + PyObject_CallObject(launch_enter_hook, args); + }} + + // raise exception asap + // {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + if (_is_icl == 0) {{ + _regular_launch(gridX, gridY, gridZ, num_warps, shared_memory, (ze_command_queue_handle_t)_queue, + (ze_device_handle_t)_dev, (ze_context_handle_t)_ctxt, (ze_kernel_handle_t)_function, + (ze_event_pool_handle_t)_event_pool + {', ' + ', '.join(f"(void *) _arg{i}" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + }} else {{ + _launch(gridX, gridY, gridZ, num_warps, shared_memory, (ze_command_list_handle_t)_stream, + (ze_kernel_handle_t)_function, (ze_event_pool_handle_t)_event_pool + {', ' + ', '.join(f"(void *) _arg{i}" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + }} + + if (launch_exit_hook != Py_None) {{ + PyObject_CallObject(launch_exit_hook, args); + }} + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; + }} + + static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel + }}; + + static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods + }}; + + PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; + }} + """ + return src + + +class XPULauncher(object): + + def __init__(self, src, metadata): + ids = { + "ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args": + metadata.get("ids_of_folded_args", + tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple() + } + constants = src.constants if hasattr(src, "constants") else dict() + enable_warp_specialization = False + src = make_launcher(constants, src.signature, ids) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class XPUDriver(DriverBase): + + def __init__(self): + self.utils = XPUUtils() + self.binary_ext = "spv" + self.launcher_cls = XPULauncher + self.get_current_stream = self.get_current_stream + self.get_current_device = self.utils.get_current_device + + def get_current_stream(self, device): + # FIXME + return 0 + + def get_current_target(self): + return ("xpu", 0) + + @staticmethod + def is_active(): + import torch + return torch.xpu.is_available() + + def assemble_tensormap_to_arg(self, tensormaps_info, args): + args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args]) + return args_ptr diff --git a/third_party/xpu/backend/lib/libsycl-spir64-unknown-unknown.bc b/third_party/xpu/backend/lib/libsycl-spir64-unknown-unknown.bc new file mode 100644 index 0000000000..69e14200b2 Binary files /dev/null and b/third_party/xpu/backend/lib/libsycl-spir64-unknown-unknown.bc differ diff --git a/third_party/xpu/triton_xpu.cc b/third_party/xpu/triton_xpu.cc new file mode 100644 index 0000000000..3cf0f7da16 --- /dev/null +++ b/third_party/xpu/triton_xpu.cc @@ -0,0 +1,133 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/GENX/GENXToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "passes.h" +#include "triton/Conversion/NVGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/TargetSelect.h" +#include +#include +#include + +namespace py = pybind11; + +PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy); + +void init_triton_xpu_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_1("add_rewrite_tensor_pointer", + mlir::createTritonGPURewriteTensorPointerPass, int); + // TODO: it is weird to pass mlir::triton::NVVM here since the conversion is + // nvidia-specificontext + m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata) { + pm.addPass(createConvertTritonGPUToLLVMPass(capability, mlir::triton::GENX, + tmaMetadata)); + }); +} + +void init_triton_xpu_passes_ttnvgpuir(py::module &&m) { + ADD_PASS_WRAPPER_1("add_plan_cta", mlir::createTritonNvidiaGPUPlanCTAPass, + mlir::triton::nvidia_gpu::ClusterInfo *); + ADD_PASS_WRAPPER_1("add_wsfeasibility_checking", + mlir::createTritonNvidiaGPUWSFeasibilityCheckingPass, int); + ADD_PASS_WRAPPER_1("add_wsdecomposing", + mlir::createTritonNvidiaGPUWSDecomposingPass, int); + ADD_PASS_WRAPPER_1("add_wsmutex", mlir::createTritonNvidiaGPUWSMutexPass, + int); + ADD_PASS_WRAPPER_1("add_wsmaterialization", + mlir::createTritonNvidiaGPUWSMaterializationPass, int); + ADD_PASS_WRAPPER_0("add_wsfixup_missing_attrs", + mlir::createTritonNvidiaGPUWSFixupMissingAttrs); + ADD_PASS_WRAPPER_2("add_materialize_load_store", + mlir::createTritonNvidiaGPUMaterializeLoadStorePass, int, + int); + ADD_PASS_WRAPPER_0("add_fence_insertion", + mlir::createTritonNvidiaGPUFenceInsertionPass); + ADD_PASS_WRAPPER_0("add_nvgpu_to_llvm", + mlir::triton::createConvertNVGPUToLLVMPass); + ADD_PASS_WRAPPER_3("add_wspipeline", + mlir::createTritonNvidiaGPUWSPipelinePass, int, int, int); + + m.def("is_ws_supported", [](mlir::ModuleOp &mod) -> bool { + return mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect::getWSSupportedAttr( + mod); + }); +} + +void init_triton_xpu(py::module &&m){ + auto passes = m.def_submodule("passes"); + init_triton_xpu_passes_ttgpuir(passes.def_submodule("ttgpuir")); + init_triton_xpu_passes_ttnvgpuir(passes.def_submodule("ttnvgpuir")); + + // cluster info + py::class_(m, "ClusterInfo") + .def(py::init<>()) + .def_readwrite("clusterDimX", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimX) + .def_readwrite("clusterDimY", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimY) + .def_readwrite("clusterDimZ", + &mlir::triton::nvidia_gpu::ClusterInfo::clusterDimZ) + .def("__repr__", [](mlir::triton::nvidia_gpu::ClusterInfo &self) { + std::ostringstream oss; + oss << "(" << self.clusterDimX << ", " << self.clusterDimY << ", " + << self.clusterDimZ << ")"; + return oss.str(); + }); + + // tma info + py::class_(m, "TMAInfo") + .def(py::init<>()) + .def_readwrite("tensorDataType", + &mlir::triton::gpu::TMAInfo::tensorDataType) + .def_readwrite("tensorRank", &mlir::triton::gpu::TMAInfo::tensorRank) + .def_readwrite("globalAddressArgIdx", + &mlir::triton::gpu::TMAInfo::globalAddressArgIdx) + .def_readwrite("globalStridesArgIdx", + &mlir::triton::gpu::TMAInfo::globalStridesArgIdx) + .def_readwrite("globalDimsArgIdx", + &mlir::triton::gpu::TMAInfo::globalDimsArgIdx) + .def_readwrite("boxDims", &mlir::triton::gpu::TMAInfo::boxDims) + .def_readwrite("elementStrides", + &mlir::triton::gpu::TMAInfo::elementStrides) + .def_readwrite("interleave", &mlir::triton::gpu::TMAInfo::interleave) + .def_readwrite("swizzle", &mlir::triton::gpu::TMAInfo::swizzle) + .def_readwrite("l2Promotion", &mlir::triton::gpu::TMAInfo::l2Promotion) + .def_readwrite("oobFill", &mlir::triton::gpu::TMAInfo::oobFill) + .def_readwrite("TMADescArgIdx", + &mlir::triton::gpu::TMAInfo::TMADescArgIdx); + py::bind_vector>(m, "TMAInfos"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + mlir::registerNVVMDialectTranslation(registry); + mlir::registerGENXDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + // TODO: could be done in python if we had a generic interface to set metadata + m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) { + // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters + // this will enable fast math path in libdevice + // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to + // sqrt.approx.ftz.f32 + using namespace llvm; + auto &ctx = mod->getContext(); + Type *i32 = Type::getInt32Ty(ctx); + auto *mdFour = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 4)); + auto *mdName = MDString::get(ctx, "nvvm-reflect-ftz"); + auto *mdOne = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 1)); + auto *reflect = MDNode::get(ctx, {mdFour, mdName, mdOne}); + mod->addModuleFlag(reflect); + }); +}