Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge OpenAI Triton commit 4dac289 #265

Merged
merged 13 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions .github/workflows/test-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,6 @@ jobs:
python3 setup.py build
python3 -m pip install --no-build-isolation -vvv '.[tests]'

- name: Run shared middle-layer lit tests
  run: |
    python3 -m pip install lit
    cd python
    # The lit tests live under the CMake build directory inside python/build.
    LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test"
    if [ ! -d "${LIT_TEST_DIR}" ]; then
      # Fail the step with a portable, non-zero exit status.
      echo "Could not find '${LIT_TEST_DIR}'" ; exit 1
    fi
    lit -v "${LIT_TEST_DIR}"


Integration-Tests-AMD:
needs: Runner-Preparation
Expand Down
11 changes: 10 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
python/triton/backends/cuda
python/triton/backends/nvidia
python/triton/backends/xpu

# Backends copied from submodules
python/triton/backends/
!python/triton/backends/__init__.py
!python/triton/backends/compiler.py
!python/triton/backends/driver.py

# Python caches
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -46,3 +52,6 @@ docs/getting-started/tutorials
/compile_commands.json
.vscode
.vs

# Vim
*.swp
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/amd"]
path = third_party/amd
url = https://github.com/ptillet/triton.git
78 changes: 53 additions & 25 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()



# Options
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
Expand Down Expand Up @@ -82,8 +84,41 @@ include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)

# Utilities
# add_triton_object(<name> [<source>...] [DEPENDS <target>...] [LINK_LIBS <lib>...])
#
# Creates the OBJECT library <name>. Every argument that does not follow a
# DEPENDS or LINK_LIBS keyword is added to the target as a PRIVATE source, and
# the target's own object files ($<TARGET_OBJECTS:...>) are added as INTERFACE
# sources.
#   DEPENDS   - targets handed to add_dependencies() for <name>.
#   LINK_LIBS - libraries linked to <name> with PUBLIC visibility.
function(add_triton_object name)
  cmake_parse_arguments(TRITON_OBJ "" "" "DEPENDS;LINK_LIBS" ${ARGN})

  add_library(${name} OBJECT)
  target_sources(${name}
    PRIVATE
      ${TRITON_OBJ_UNPARSED_ARGUMENTS}
    INTERFACE
      $<TARGET_OBJECTS:${name}>
  )

  # Both keyword groups are optional; only act on the ones that were supplied.
  if(TRITON_OBJ_LINK_LIBS)
    target_link_libraries(${name} PUBLIC ${TRITON_OBJ_LINK_LIBS})
  endif()
  if(TRITON_OBJ_DEPENDS)
    add_dependencies(${name} ${TRITON_OBJ_DEPENDS})
  endif()
endfunction()

# Global list of the targets created through add_triton_library(); starts empty.
set_property(GLOBAL PROPERTY TRITON_LIBS "")

# add_triton_library(<name> <args>...)
#
# Creates <name> by forwarding all arguments to add_triton_object(), applies
# llvm_update_compile_flags() to the new target, and appends <name> to the
# global TRITON_LIBS property.
function(add_triton_library name)
  add_triton_object(${name} ${ARGN})
  llvm_update_compile_flags(${name})
  set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name})
endfunction()

# Global list of the targets created through add_triton_plugin(); starts empty.
set_property(GLOBAL PROPERTY TRITON_PLUGINS "")

# add_triton_plugin(<name> <args>...)
#
# Creates <name> by forwarding all arguments to add_triton_object() and appends
# <name> to the global TRITON_PLUGINS property.
function(add_triton_plugin name)
  add_triton_object(${name} ${ARGN})
  set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name})
endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
Expand All @@ -99,9 +134,6 @@ add_subdirectory(lib)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

# TODO: Figure out which target is sufficient to fix the errors; the triton
# target alone is apparently not enough. For now, libstdc++fs is linked into
# all targets to support old GCC versions such as 8.3.0.
Expand All @@ -128,33 +160,26 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_link_options(${Python3_LINK_OPTIONS})
endif()

# Code-generation backends to build. Each entry names a directory under
# third_party/ that is added to the build below.
set(TRITON_CODEGEN_BACKENDS "xpu")
# The loop variable CODEGEN_BACKEND remains visible while each backend's
# directory is processed.
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()


get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
set(TRITON_LIBRARIES
TritonIR
TritonAnalysis
TritonTransforms
TritonToTritonGPU
TritonGPUIR
TritonGPUTransforms
TritonLLVMIR
TritonNvidiaGPUIR
MLIRAMDGPUDialect
TritonAnalysis
NVGPUToLLVM
TritonNvidiaGPUTransforms
TritonGPUToLLVM
${triton_libs}
${triton_plugins}
TritonSPIRV

# mlir
MLIRAMDGPUDialect
MLIRNVVMDialect
MLIRNVVMToLLVMIRTranslation
MLIRGPUToNVVMTransforms
MLIRGPUToGPURuntimeTransforms
MLIRGPUTransforms

# optimizations
MLIRIR
MLIRControlFlowToLLVM
MLIRBytecodeWriter
MLIRPass
Expand All @@ -166,7 +191,9 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRROCDLToLLVMIRTranslation
MLIRGENXToLLVMIRTranslation
MLIRGPUDialect
MLIRIR
MLIRSCFToControlFlow
MLIRIndexToLLVM
MLIRGPUToROCDLTransforms

# LLVM
LLVMPasses
Expand All @@ -180,12 +207,14 @@ if(TRITON_BUILD_PYTHON_MODULE)
)

# Define triton library
# Expose the backend list to the C++ sources as the TRITON_BACKENDS_TUPLE
# compile definition: the names in TRITON_CODEGEN_BACKENDS joined with ","
# and wrapped in parentheses.
string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS})
set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})")
add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE})
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
${PYTHON_SRC_PATH}/ir.cc
${PYTHON_SRC_PATH}/passes.cc
${PYTHON_SRC_PATH}/interpreter.cc
${PYTHON_SRC_PATH}/llvm.cc
${CMAKE_CURRENT_SOURCE_DIR}/third_party/xpu/triton_xpu.cc)
${PYTHON_SRC_PATH}/llvm.cc)

# Link triton with its dependencies
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
Expand All @@ -195,7 +224,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PRIVATE z)
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS} ${GenISAIntrinsics_LDFLAGS})
set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "")
endif()

if(UNIX AND NOT APPLE)
Expand All @@ -210,7 +238,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
endif()

target_link_libraries(triton ${PYTHON_LDFLAGS})
target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS})
endif()

add_subdirectory(bin)
Expand Down
25 changes: 25 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ target_link_libraries(triton-opt PRIVATE
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
${triton_libs}
# tests
TritonTestAnalysis
# MLIR core
Expand All @@ -33,6 +34,7 @@ target_link_libraries(triton-reduce PRIVATE
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
${triton_libs}
# tests
TritonTestAnalysis
# MLIR core
Expand All @@ -43,6 +45,29 @@ target_link_libraries(triton-reduce PRIVATE

mlir_check_all_link_libraries(triton-reduce)

# triton-lsp: executable built from triton-lsp.cpp.
add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED)

llvm_update_compile_flags(triton-lsp)
target_link_libraries(triton-lsp PRIVATE
  TritonAnalysis
  TritonTransforms
  TritonGPUTransforms
  TritonNvidiaGPUTransforms
  ${dialect_libs}
  ${conversion_libs}
  ${triton_libs}
  # tests
  TritonTestAnalysis
  # MLIR core
  MLIRLspServerLib
  MLIRPass
  MLIRTransforms
)

# Run the MLIR link-library check once, after target_link_libraries(), so that
# the libraries listed above are part of what is checked.
mlir_check_all_link_libraries(triton-lsp)


add_llvm_executable(triton-llvm-opt
triton-llvm-opt.cpp

Expand Down
11 changes: 11 additions & 0 deletions bin/triton-lsp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "./RegisterTritonDialects.h"

#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"

// Entry point of the triton-lsp tool: registers the Triton dialects in a
// DialectRegistry and hands control to MLIR's LSP server driver.
//
// Returns 0 when MlirLspServerMain succeeds and 1 when it reports failure.
int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  registerTritonDialects(registry);

  // The server driver takes the registry directly, so no MLIRContext has to be
  // constructed here. mlir::failed() turns the result into the exit code.
  return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));
}
9 changes: 9 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,14 @@ def TT_RoundingModeAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// PropagateNan: 32-bit integer enum attribute with two cases, NONE (value 0,
// keyword "none") and ALL (value 0xFFFF, keyword "all"). The generated C++
// enum is placed in the ::mlir::triton namespace.
// NOTE(review): presumably selects whether an op propagates NaN inputs to its
// result -- confirm against the lowering of the ops that carry this attribute.
def TT_PropagateNanAttr : I32EnumAttr<
"PropagateNan", "",
[
I32EnumAttrCase<"NONE", 0, "none">,
I32EnumAttrCase<"ALL", 0xFFFF, "all">,
]> {
let cppNamespace = "::mlir::triton";
}

#endif
18 changes: 18 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
let hasVerifier = 1;
}

// "clampf" op: floating-point clamp of $x to the range [$min, $max].
// Declared with the Elementwise, SameOperandsAndResultType and Pure traits, so
// x, min, max and the result all share one type.
def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
SameOperandsAndResultType,
Pure]> {
let summary = "Clamp operation for floating point types";

let description = [{
Clamp operation for floating point types.

The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max].
}];

// Three float-like operands plus the propagateNan enum attribute.
// NOTE(review): the description above does not say how propagateNan affects
// the result -- confirm the NaN semantics with the lowering.
let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$min, TT_FloatLike:$max, TT_PropagateNanAttr:$propagateNan);

let results = (outs TT_FloatLike:$result);

// Only the result type is spelled out. propagateNan is not named in the
// format, so it is carried in attr-dict.
let assemblyFormat = "$x `,` $min `,` $max attr-dict `:` type($result)";
}

//
// Pointer Arith Ops
//
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Target/PTX/TmaMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TRITON_TARGET_PTX_TMAMETADATA_H
#define TRITON_TARGET_PTX_TMAMETADATA_H

#include "third_party/cuda/backend/include/cuda.h"
#include "third_party/nvidia/backend/include/cuda.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
Expand Down
3 changes: 1 addition & 2 deletions lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_mlir_library(TritonAnalysis
add_triton_library(TritonAnalysis
AxisInfo.cpp
Allocation.cpp
Membar.cpp
Expand All @@ -10,7 +10,6 @@ add_mlir_library(TritonAnalysis
TritonGPUAttrDefsIncGen

LINK_LIBS PUBLIC
ASMBuilder
MLIRAnalysis
MLIRLLVMDialect
TritonIR
Expand Down
9 changes: 1 addition & 8 deletions lib/Conversion/NVGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
add_mlir_conversion_library(NVGPUToLLVM
add_triton_library(NVGPUToLLVM
NVGPUToLLVMPass.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/NVGPUToLLVM
${PROJECT_BINARY_DIR}/include/triton/Conversion/NVGPUToLLVM

DEPENDS
NVGPUConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
Expand Down
30 changes: 3 additions & 27 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
# Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis
# depends on it.
set(LLVM_OPTIONAL_SOURCES
PTXAsmFormat.cpp
)

add_mlir_conversion_library(TritonGPUToLLVM
add_triton_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp
Expand All @@ -16,8 +10,8 @@ add_mlir_conversion_library(TritonGPUToLLVM
DotOpToLLVM/MMAv2.cpp
DotOpToLLVM/WGMMA.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
HistogramOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
BarrierOpToLLVM.cpp
TritonGPUToLLVM.cpp
Expand All @@ -30,19 +24,12 @@ add_mlir_conversion_library(TritonGPUToLLVM
TensorPtrOpsToLLVM.cpp
ClusterOpsToLLVM.cpp
RegReallocOpToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonGPUToLLVM
PTXAsmFormat.cpp

DEPENDS
TritonGPUConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
ASMBuilder
MLIRIR
MLIRPass
MLIRGENXDialect
Expand All @@ -58,14 +45,3 @@ add_mlir_conversion_library(TritonGPUToLLVM
TritonNvidiaGPUTransforms
NVGPUIR
)

add_mlir_library(ASMBuilder
PTXAsmFormat.cpp

DEPENDS
TritonTableGen

LINK_LIBS PUBLIC
MLIRAnalysis
MLIRLLVMDialect
)
Loading
Loading