diff --git a/.gitignore b/.gitignore index d533f6f099..8860d6de20 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so -python/triton/backends/cuda +python/triton/backends/nvidia python/triton/backends/xpu # Python caches diff --git a/.gitmodules b/.gitmodules index e69de29bb2..d964b77672 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/amd"] + path = third_party/amd + url = https://github.com/ptillet/triton.git diff --git a/CMakeLists.txt b/CMakeLists.txt index d9a1586f9f..c8cdb5742e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,8 @@ if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() + + # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) @@ -82,8 +84,41 @@ include(TableGen) # required by AddMLIR include(AddLLVM) include(AddMLIR) +# Utilities +function(add_triton_object name) + cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN}) + add_library(${name} OBJECT) + target_sources(${name} + PRIVATE ${ARG_UNPARSED_ARGUMENTS} + INTERFACE $ + ) + + + # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + if(ARG_DEPENDS) + add_dependencies(${name} ${ARG_DEPENDS}) + endif() + if(ARG_LINK_LIBS) + target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + endif() +endfunction(add_triton_object) + +set_property(GLOBAL PROPERTY TRITON_LIBS "") +function(add_triton_library name) + set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) + add_triton_object(${name} ${ARGN}) + llvm_update_compile_flags(${name}) +endfunction() + +set_property(GLOBAL PROPERTY TRITON_PLUGINS "") +function(add_triton_plugin name) + set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) + add_triton_object(${name} ${ARGN}) +endfunction() + + # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) @@ -99,9 +134,6 @@ add_subdirectory(lib) set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) - # TODO: Figure out which target is sufficient to fix errors; triton is # apparently not enough. Currently set linking libstdc++fs for all targets # to support some old version GCC compilers like 8.3.0. @@ -128,33 +160,26 @@ if(TRITON_BUILD_PYTHON_MODULE) add_link_options(${Python3_LINK_OPTIONS}) endif() - set(TRITON_CODEGEN_BACKENDS "xpu") foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) set(TRITON_LIBRARIES - TritonIR - TritonAnalysis - TritonTransforms - TritonToTritonGPU - TritonGPUIR - TritonGPUTransforms - TritonLLVMIR - TritonNvidiaGPUIR - MLIRAMDGPUDialect - TritonAnalysis - NVGPUToLLVM - TritonNvidiaGPUTransforms - TritonGPUToLLVM + ${triton_libs} + ${triton_plugins} TritonSPIRV + + # mlir + MLIRAMDGPUDialect MLIRNVVMDialect MLIRNVVMToLLVMIRTranslation MLIRGPUToNVVMTransforms MLIRGPUToGPURuntimeTransforms MLIRGPUTransforms - - # optimizations + MLIRIR MLIRControlFlowToLLVM MLIRBytecodeWriter MLIRPass @@ -166,7 +191,9 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRROCDLToLLVMIRTranslation MLIRGENXToLLVMIRTranslation MLIRGPUDialect - MLIRIR + MLIRSCFToControlFlow + MLIRIndexToLLVM + MLIRGPUToROCDLTransforms # LLVM LLVMPasses @@ -180,12 +207,14 @@ if(TRITON_BUILD_PYTHON_MODULE) ) # Define triton library + string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS}) + set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") + add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc ${PYTHON_SRC_PATH}/passes.cc ${PYTHON_SRC_PATH}/interpreter.cc - ${PYTHON_SRC_PATH}/llvm.cc - ${CMAKE_CURRENT_SOURCE_DIR}/third_party/xpu/triton_xpu.cc) + ${PYTHON_SRC_PATH}/llvm.cc) # Link triton with its dependencies target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) @@ -195,7 +224,6 @@ if(TRITON_BUILD_PYTHON_MODULE) target_link_libraries(triton PRIVATE z) endif() target_link_options(triton PRIVATE ${LLVM_LDFLAGS} ${GenISAIntrinsics_LDFLAGS}) - set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "") endif() if(UNIX AND NOT APPLE) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 91e76b5c38..ba600fa9e1 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -12,6 +12,7 @@ target_link_libraries(triton-opt PRIVATE TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} + ${triton_libs} # tests TritonTestAnalysis # MLIR core @@ -33,6 +34,7 @@ target_link_libraries(triton-reduce PRIVATE TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} + ${triton_libs} # tests TritonTestAnalysis # MLIR core diff --git a/include/triton/Target/PTX/TmaMetadata.h b/include/triton/Target/PTX/TmaMetadata.h index eb11a74693..0eb9e14377 100644 --- a/include/triton/Target/PTX/TmaMetadata.h +++ b/include/triton/Target/PTX/TmaMetadata.h @@ -24,7 +24,7 @@ #ifndef TRITON_TARGET_PTX_TMAMETADATA_H #define TRITON_TARGET_PTX_TMAMETADATA_H -#include "third_party/cuda/backend/include/cuda.h" +#include "third_party/nvidia/backend/include/cuda.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index aecc2345ac..a84f0649b6 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_library(TritonAnalysis +add_triton_library(TritonAnalysis AxisInfo.cpp Allocation.cpp Membar.cpp @@ -10,7 +10,6 @@ add_mlir_library(TritonAnalysis TritonGPUAttrDefsIncGen LINK_LIBS PUBLIC - ASMBuilder MLIRAnalysis MLIRLLVMDialect TritonIR diff --git a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt index 9af2636866..153a9d6de3 100644 --- a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt @@ -1,16 +1,9 @@ -add_mlir_conversion_library(NVGPUToLLVM +add_triton_library(NVGPUToLLVM NVGPUToLLVMPass.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/NVGPUToLLVM - ${PROJECT_BINARY_DIR}/include/triton/Conversion/NVGPUToLLVM - DEPENDS NVGPUConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC MLIRIR MLIRPass diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 76a2e3438f..d6349d86a9 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,10 +1,4 @@ -# Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis -# depends on it. -set(LLVM_OPTIONAL_SOURCES - PTXAsmFormat.cpp - ) - -add_mlir_conversion_library(TritonGPUToLLVM +add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -16,8 +10,8 @@ add_mlir_conversion_library(TritonGPUToLLVM DotOpToLLVM/MMAv2.cpp DotOpToLLVM/WGMMA.cpp DotOpToLLVM.cpp - ElementwiseOpToLLVM.cpp HistogramOpToLLVM.cpp + ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp @@ -30,19 +24,12 @@ add_mlir_conversion_library(TritonGPUToLLVM TensorPtrOpsToLLVM.cpp ClusterOpsToLLVM.cpp RegReallocOpToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM - ${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonGPUToLLVM + PTXAsmFormat.cpp DEPENDS TritonGPUConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - ASMBuilder MLIRIR MLIRPass MLIRGENXDialect @@ -58,14 +45,3 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonNvidiaGPUTransforms NVGPUIR ) - -add_mlir_library(ASMBuilder - PTXAsmFormat.cpp - - DEPENDS - TritonTableGen - - LINK_LIBS PUBLIC - MLIRAnalysis - MLIRLLVMDialect -) diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 834d10a4de..d770aeb22c 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,16 +1,9 @@ -add_mlir_conversion_library(TritonToTritonGPU +add_triton_library(TritonToTritonGPU TritonToTritonGPUPass.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU - ${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonToTritonGPU - DEPENDS TritonConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC MLIRIR MLIRPass diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt index 24a93ce58e..1fd118d2be 100644 --- a/lib/Dialect/NVGPU/IR/CMakeLists.txt +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(NVGPUIR +add_triton_library(NVGPUIR Dialect.cpp DEPENDS diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 6ee110718c..71165b17fc 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonIR +add_triton_library(TritonIR Dialect.cpp Ops.cpp Types.cpp diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index d06c01566c..2983987506 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -2,7 +2,7 @@ set(LLVM_TARGET_DEFINITIONS Combine.td) mlir_tablegen(TritonCombine.inc -gen-rewriters) add_public_tablegen_target(TritonCombineIncGen) -add_mlir_dialect_library(TritonTransforms +add_triton_library(TritonTransforms Combine.cpp ReorderBroadcast.cpp RewriteTensorPointer.cpp diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index bab4a5dec4..82cf23f052 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonGPUIR +add_triton_library(TritonGPUIR Dialect.cpp Traits.cpp Types.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 2cd6e26725..c93feb8151 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonGPUTransforms +add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Coalesce.cpp DecomposeConversions.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index 99f2ef6b70..4369542327 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonNvidiaGPUIR +add_triton_library(TritonNvidiaGPUIR Dialect.cpp Ops.cpp Traits.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 53674ebfc6..a147b7e996 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonNvidiaGPUTransforms +add_triton_library(TritonNvidiaGPUTransforms MaterializeLoadStore.cpp PlanCTA.cpp WSDecomposing.cpp diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index 510cfab9c8..f2f9adf8f4 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,10 +1,7 @@ -add_mlir_translation_library(TritonLLVMIR +add_triton_library(TritonLLVMIR LLVMDIScope.cpp LLVMIRBreakPhiStruct.cpp - LINK_COMPONENTS - Core - DEPENDS LLVMIRIncGen diff --git a/python/setup.py b/python/setup.py index a23cf2b11b..79390f2918 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,6 +15,45 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +from dataclasses import dataclass + + +@dataclass +class Backend: + name: str + package_data: dict + src_dir: str + + +def _copy_backends(active): + ret = [] + root_dir = os.path.join(os.pardir, "third_party") + for backend in active: + curr_path = os.path.join(root_dir, backend) + backend_path = os.path.join(curr_path, "backend") + # check conditions + assert backend in os.listdir(root_dir), f"{backend} is requested for install but not present in {root_dir}" + assert os.listdir(curr_path), f"{curr_path} is empty!" + assert os.path.exists(backend_path), f"{backend_path} does not exist!" + for file in ["compiler.py", "driver.py"]: + assert os.path.exists(os.path.join(backend_path, file)) + # initialize submodule if there is one + try: + subprocess.run(["git", "submodule", "update", "--init", f"{backend}"], check=True, + stdout=subprocess.DEVNULL, cwd=root_dir) + except subprocess.CalledProcessError: + pass + except FileNotFoundError: + pass + # copy backend over + dst_path = os.path.join(os.path.dirname(__file__), "triton", "backends", backend) + if os.path.exists(dst_path): + shutil.rmtree(dst_path) + shutil.copytree(backend_path, dst_path) + # update + package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] + ret.append(Backend(name=backend, package_data=package_data, src_dir=curr_path)) + return ret # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py @@ -106,6 +145,9 @@ def open_url(url): return urllib.request.urlopen(request) +# ---- package data --- + + def get_thirdparty_packages(triton_cache_path): packages = [get_pybind11_package_info(), get_llvm_package_info()] thirdparty_cmake_args = [] @@ -135,9 +177,6 @@ def get_thirdparty_packages(triton_cache_path): return thirdparty_cmake_args -# ---- package data --- - - def download_and_copy(src_path, variable, version, url_func): if variable in os.environ: return @@ -146,7 +185,7 @@ def download_and_copy(src_path, variable, version, url_func): if arch == "x86_64": arch = "64" url = url_func(arch, version) - dst_path = os.path.join(base_dir, os.pardir, "third_party", "cuda", "backend", src_path) + dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) is_linux = platform.system() == "Linux" download = False if is_linux: @@ -249,20 +288,15 @@ def build_extension(self, ext): # python directories python_include_dir = sysconfig.get_path("platinclude") cmake_args = [ - "-G", - "Ninja", # Ninja is much faster than make + "-G", "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", - "-DLLVM_ENABLE_WERROR=ON", - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DLLVM_SPIRV_DIR=" + llvm_spirv_path, - "-DTRITON_BUILD_TUTORIALS=OFF", - "-DTRITON_BUILD_PYTHON_MODULE=ON", - "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, - "-DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=ON", - "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", - "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", + "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, + "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends]) ] if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) @@ -345,14 +379,11 @@ def build_extension(self, ext): url_func=lambda arch, version: f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) +backends = _copy_backends(["xpu"]) -plugins = ["xpu"] -for plugin in plugins: - src_path = os.path.join(os.pardir, "third_party", plugin, "backend") - dst_path = os.path.join(os.path.dirname(__file__), "triton", "backends", plugin) - if os.path.exists(dst_path): - shutil.rmtree(dst_path) - shutil.copytree(src_path, dst_path) +package_data = dict() +package_data["triton/tools"] = ["compile.h", "compile.c"] +package_data.update({f"triton/backends/{b.name}": b.package_data for b in backends}) setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), @@ -371,14 +402,10 @@ def build_extension(self, ext): "triton/ops/blocksparse", "triton/runtime", "triton/backends", - "triton/backends/xpu", "triton/tools", - ], + ] + [f'triton/backends/{backend.name}' for backend in backends], install_requires=["filelock"], - package_data={ - "triton/tools": ["compile.h", "compile.c"], - "triton/backends/xpu": ["bin/*", "lib/*", "include/*"], - }, + package_data=package_data, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, diff --git a/python/src/main.cc b/python/src/main.cc index 36f5b2679f..5ad4be7d55 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -1,12 +1,43 @@ #include namespace py = pybind11; +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + void init_triton_env_vars(pybind11::module &m); void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); -void init_triton_xpu(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; @@ -15,5 +46,5 @@ PYBIND11_MODULE(libtriton, m) { init_triton_passes(m.def_submodule("passes")); init_triton_interpreter(m.def_submodule("interpreter")); init_triton_llvm(m.def_submodule("llvm")); - init_triton_xpu(m.def_submodule("xpu")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) } diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 5ec600e92e..4565696575 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -7,7 +7,7 @@ import numpy as np import triton -from triton.backends.cuda.driver import include_dir, library_dir +from triton.backends.nvidia.driver import include_dir, library_dir kernel_utils_src = """ import triton diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index cabcec71b8..fbf65d9e90 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -1,5 +1,5 @@ import os -import importlib +import importlib.util import inspect from dataclasses import dataclass from .driver import DriverBase diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index e3c7743795..1bf9d84702 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -6,6 +6,5 @@ add_mlir_library(TritonTestAnalysis LINK_LIBS PUBLIC MLIRPass - TritonAnalysis - ${dialect_libs} + ${triton_libs} ) diff --git a/third_party/.diff.swp b/third_party/.diff.swp new file mode 100644 index 0000000000..6605263e5f Binary files /dev/null and b/third_party/.diff.swp differ diff --git a/third_party/amd b/third_party/amd new file mode 160000 index 0000000000..a3c7061800 --- /dev/null +++ b/third_party/amd @@ -0,0 +1 @@ +Subproject commit a3c7061800f31db179ba34e1369725841ec8cb0d diff --git a/third_party/cuda/CMakeLists.txt b/third_party/cuda/CMakeLists.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt new file mode 100644 index 0000000000..f6a91676b5 --- /dev/null +++ b/third_party/nvidia/CMakeLists.txt @@ -0,0 +1 @@ +add_triton_plugin(TritonNVIDIA ${CMAKE_CURRENT_SOURCE_DIR}/triton_nvidia.cc) \ No newline at end of file diff --git a/third_party/cuda/backend/__init__.py b/third_party/nvidia/backend/__init__.py similarity index 100% rename from third_party/cuda/backend/__init__.py rename to third_party/nvidia/backend/__init__.py diff --git a/third_party/cuda/backend/compiler.py b/third_party/nvidia/backend/compiler.py similarity index 100% rename from third_party/cuda/backend/compiler.py rename to third_party/nvidia/backend/compiler.py diff --git a/third_party/cuda/backend/driver.c b/third_party/nvidia/backend/driver.c similarity index 100% rename from third_party/cuda/backend/driver.c rename to third_party/nvidia/backend/driver.c diff --git a/third_party/cuda/backend/driver.py b/third_party/nvidia/backend/driver.py similarity index 100% rename from third_party/cuda/backend/driver.py rename to third_party/nvidia/backend/driver.py diff --git a/third_party/cuda/backend/include/cuda.h b/third_party/nvidia/backend/include/cuda.h similarity index 100% rename from third_party/cuda/backend/include/cuda.h rename to third_party/nvidia/backend/include/cuda.h diff --git a/third_party/cuda/backend/lib/libdevice.10.bc b/third_party/nvidia/backend/lib/libdevice.10.bc similarity index 100% rename from third_party/cuda/backend/lib/libdevice.10.bc rename to third_party/nvidia/backend/lib/libdevice.10.bc diff --git a/third_party/cuda/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc similarity index 100% rename from third_party/cuda/triton_nvidia.cc rename to third_party/nvidia/triton_nvidia.cc diff --git a/third_party/xpu/CMakeLists.txt b/third_party/xpu/CMakeLists.txt index e69de29bb2..72ab20fac6 100644 --- a/third_party/xpu/CMakeLists.txt +++ b/third_party/xpu/CMakeLists.txt @@ -0,0 +1 @@ +add_triton_plugin(TritonXPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_xpu.cc) \ No newline at end of file diff --git a/unittest/Analysis/CMakeLists.txt b/unittest/Analysis/CMakeLists.txt index af11f1d807..e94696bf5f 100644 --- a/unittest/Analysis/CMakeLists.txt +++ b/unittest/Analysis/CMakeLists.txt @@ -6,4 +6,5 @@ add_triton_ut( TritonIR TritonGPUIR ${dialect_libs} + ${triton_libs} ) diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index 0ba2be07f2..592d1b7c23 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,5 +7,5 @@ add_triton_ut( add_triton_ut( NAME TestEmitIndices SRCS EmitIndicesTest.cpp DumpLayout.cpp - LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} + LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} ${triton_libs} ) diff --git a/unittest/Dialect/TritonGPU/CMakeLists.txt b/unittest/Dialect/TritonGPU/CMakeLists.txt index 3dfa69701f..28576d7fd4 100644 --- a/unittest/Dialect/TritonGPU/CMakeLists.txt +++ b/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -1,5 +1,5 @@ add_triton_ut( NAME TestSwizzling SRCS SwizzleTest.cpp - LIBS TritonGPUIR TritonNvidiaGPUIR TritonTransforms ${dialect_libs} ${conversion_libs} + LIBS TritonGPUIR TritonNvidiaGPUIR TritonTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} )