Skip to content

Commit

Permalink
[FRONTEND] use standard plugin interface for CUDA (#2887)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Jan 8, 2024
1 parent 8441be4 commit 4803403
Show file tree
Hide file tree
Showing 47 changed files with 1,680 additions and 2,057 deletions.
10 changes: 0 additions & 10 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
[submodule "third_party/intel_xpu_backend"]
path = third_party/intel_xpu_backend
url = http://github.com/intel/intel-xpu-backend-for-triton
[submodule "third_party/amd_hip_backend"]
path = third_party/amd_hip_backend
url = https://github.com/ROCmSoftwarePlatform/triton
branch = third_party_backend_2
[submodule "third_party/triton_shared"]
path = third_party/triton_shared
url = https://github.com/microsoft/triton-shared
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ repos:
exclude: |
(?x)(
^include/triton/external/|
^python/triton/third_party/
^third_party/
)
exclude: |
(?x)(
^include/triton/external/|
^python/triton/third_party/
^third_party/
)
216 changes: 65 additions & 151 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,141 +43,20 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")

# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")

if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()

# #########
# LLVM
# #########
if(NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)

include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS_LIST})

llvm_map_components_to_libnames(LLVM_LIBRARIES support core
NVPTXInfo nvptxcodegen
AMDGPUInfo AMDGPUcodegen
)
else()
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
endif()

message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")

# FindLLVM outputs LLVM_LIBRARY_DIRS but we expect LLVM_LIBRARY_DIR here
set(LLVM_LIBRARY_DIR ${LLVM_LIBRARY_DIRS})

if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
endif()

# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES
LLVMNVPTXCodeGen
LLVMNVPTXDesc
LLVMNVPTXInfo
LLVMAMDGPUDisassembler
LLVMMCDisassembler
LLVMAMDGPUCodeGen
LLVMMIRParser
LLVMGlobalISel
LLVMSelectionDAG
LLVMipo
LLVMInstrumentation
LLVMVectorize
LLVMLinker
LLVMIRReader
LLVMAsmParser
LLVMFrontendOpenMP
LLVMAsmPrinter
LLVMDebugInfoDWARF
LLVMCodeGen
LLVMTarget
LLVMScalarOpts
LLVMInstCombine
LLVMAggressiveInstCombine
LLVMTransformUtils
LLVMBitWriter
LLVMAnalysis
LLVMProfileData
LLVMObject
LLVMTextAPI
LLVMBitReader
LLVMAMDGPUAsmParser
LLVMMCParser
LLVMAMDGPUDesc
LLVMAMDGPUUtils
LLVMMC
LLVMDebugInfoCodeView
LLVMDebugInfoMSF
LLVMCore
LLVMRemarks
LLVMBitstreamReader
LLVMBinaryFormat
LLVMAMDGPUInfo
LLVMSupport
LLVMDemangle
LLVMPasses
LLVMAnalysis
LLVMTransformUtils
LLVMScalarOpts
LLVMTransformUtils
LLVMipo
LLVMObjCARCOpts
LLVMCoroutines
LLVMAnalysis
)
endif()

set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
endif()

# Python module
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc
${PYTHON_SRC_PATH}/nvidia.cc)
include_directories("." ${PYTHON_SRC_PATH})

if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()
endif()

# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()

# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})

Expand All @@ -191,6 +70,7 @@ include(AddMLIR)
# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
Expand All @@ -214,49 +94,91 @@ if (NOT WIN32 AND NOT APPLE)
link_libraries(stdc++fs)
endif()


# -----

# ------
if(TRITON_BUILD_PYTHON_MODULE)
add_library(triton SHARED ${PYTHON_SRC})
message(STATUS "Adding Python module")
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
include_directories(${PYTHON_SRC_PATH})

if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()

set(TRITON_CODEGEN_BACKENDS "cuda")
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()

set(TRITON_LIBRARIES
TritonIR
TritonAnalysis
TritonTransforms
TritonToTritonGPU
TritonGPUIR
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonLLVMIR
${dialect_libs}
${conversion_libs}
TritonNvidiaGPUIR
MLIRAMDGPUDialect
TritonAnalysis
NVGPUToLLVM
TritonNvidiaGPUTransforms
TritonGPUToLLVM
MLIRNVVMDialect
MLIRNVVMToLLVMIRTranslation
MLIRGPUToNVVMTransforms
MLIRGPUToGPURuntimeTransforms
MLIRGPUTransforms

# optimizations
MLIRControlFlowToLLVM
MLIRBytecodeWriter
MLIRPass
MLIRTransforms
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
MLIRGPUDialect
MLIRIR

# LLVM
LLVMPasses
LLVMX86CodeGen
LLVMX86AsmParser
LLVMNVPTXCodeGen
# LLVMNVPTXAsmPrinter
LLVMAMDGPUCodeGen
LLVMAMDGPUAsmParser

)

# Define triton library
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuda/triton_nvidia.cc)

# Link triton with its dependencies
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
)
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
target_link_libraries(triton PRIVATE z)
endif()

target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
endif()

if(UNIX AND NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "")
endif()

if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
Expand All @@ -267,15 +189,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
endif()

target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif()

list(LENGTH TRITON_CODEGEN_BACKENDS CODEGEN_BACKENDS_LEN)
if (${CODEGEN_BACKENDS_LEN} GREATER 0)
set(PYTHON_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/third_party)
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
target_link_libraries(triton ${PYTHON_LDFLAGS})
endif()

add_subdirectory(bin)
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Target/PTX/TmaMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
#define TRITON_TARGET_PTX_TMAMETADATA_H

#include "python/triton/third_party/cuda/include/cuda.h"
#include "third_party/cuda/backend/include/cuda.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
Expand Down
13 changes: 0 additions & 13 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,19 +557,6 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
!srcTy.getElementType().isF32();
}

bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
// It means that ptr is a resultant of broadcast or generated through
// a chain of broadcast and other operations.
// Rematerialize it without considering contiguous memory access pattern is
// fine.
return true;
}

namespace {

/// A data structure similar to SetVector but maintains
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,19 @@ std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding) {
return encoding;
}

bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
// It means that ptr is a resultant of broadcast or generated through
// a chain of broadcast and other operations.
// Rematerialize it without considering contiguous memory access pattern is
// fine.
return true;
}

bool isExpensiveLoadOrStore(Operation *op) {
// Case 1: Pointer of tensor is always expensive
auto operandType = op->getOperand(0).getType();
Expand Down
5 changes: 0 additions & 5 deletions python/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
graft src
graft triton/third_party
graft triton/tools
graft triton/runtime/backends/
graft triton/language/extra
Loading

0 comments on commit 4803403

Please sign in to comment.