Commit
Merge commit '51dddd36d2a094a3503c8f4562cad37ed6f96c08'
whitneywhtsang committed Jan 9, 2025
2 parents bcd932c + 51dddd3 commit ab58512
Showing 16 changed files with 292 additions and 219 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -150,6 +150,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs, for all
kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only.
- The Triton cache can interfere with the dump. If `MLIR_ENABLE_DUMP=1` does not work, try clearing your Triton cache: `rm -r ~/.triton/cache/*`
- `MLIR_DUMP_PATH` specifies where `MLIR_ENABLE_DUMP` will dump to. If unset, the dump goes to stderr.
- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR.
- `TRITON_REPRODUCER_PATH=<reproducer_path>` will generate an MLIR reproducer file
at `<reproducer_path>` before each MLIR compiler stage. If any of the stages fail,
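A minimal sketch of how the dump variables above might be driven from Python; the `add_kernel` kernel, shapes, and dump path are hypothetical, and the variables must be set before the kernel is first compiled (clear the Triton cache if it was compiled earlier):

```python
import os

# Set before Triton compiles the kernel; clear ~/.triton/cache if it was
# already compiled under different settings.
os.environ["MLIR_ENABLE_DUMP"] = "add_kernel"          # or "1" for all kernels
os.environ["MLIR_DUMP_PATH"] = "/tmp/add_kernel.mlir"  # omit to dump to stderr

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
add_kernel[(4,)](x, y, out, x.numel(), BLOCK=256)
# The IR before each MLIR pass over add_kernel is now in /tmp/add_kernel.mlir.
```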
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -205,6 +205,7 @@ Compiler Hint Ops
:toctree: generated
:nosignatures:

assume
debug_barrier
max_constancy
max_contiguous
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"LLVM_PASS_PLUGIN_PATH",
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_DUMP_PATH",
"MLIR_ENABLE_TIMING",
"TRITON_DEFAULT_FP_FUSION",
"TRITON_DISABLE_LINE_INFO",
19 changes: 18 additions & 1 deletion python/src/ir.cc
@@ -29,6 +29,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/SourceMgr.h"

#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
@@ -39,6 +40,22 @@ namespace py = pybind11;
using namespace mlir;
using namespace triton;

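// Opens the file named by MLIR_DUMP_PATH once (the stream is a function-local
// static created on first use) and reuses it for all subsequent dumps.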
llvm::raw_fd_ostream &mlir_dumps() {
std::error_code EC;
static llvm::raw_fd_ostream S(::triton::tools::getStrEnv("MLIR_DUMP_PATH"),
EC, llvm::sys::fs::CD_CreateAlways);
assert(!EC);
return S;
}

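// Returns the MLIR_DUMP_PATH stream when that variable is set, llvm::dbgs() otherwise.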
llvm::raw_ostream &mlir_dumps_or_dbgs() {
if (!::triton::tools::getStrEnv("MLIR_DUMP_PATH").empty()) {
return mlir_dumps();
} else {
return llvm::dbgs();
}
}

// A custom op builder that keeps track of the last location
class TritonOpBuilder {
public:
@@ -1711,7 +1728,7 @@ void init_triton_ir(py::module &&m) {
/*shouldPrintAfterPass=*/printAlways,
/*printModuleScope=*/true,
/*printAfterOnlyOnChange=*/false,
/*printAfterOnlyOnFailure*/ true, llvm::dbgs(),
/*printAfterOnlyOnFailure*/ true, mlir_dumps_or_dbgs(),
printingFlags);
}
})
36 changes: 1 addition & 35 deletions python/test/backend/test_device_backend.py
@@ -8,13 +8,11 @@
import tempfile
from pathlib import Path

import setuptools
import torch

import triton
import triton.language as tl
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
from triton.common.build import quiet
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
@@ -43,39 +41,7 @@ def build_for_backend(name, src, srcdir):
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

ret = subprocess.check_call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so])
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = []
include_dirs = [srcdir]
libraries = []
# extra arguments
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name=name,
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
)
# build extension module
args = ['build_ext']
args.append('--build-temp=' + srcdir)
args.append('--build-lib=' + srcdir)
args.append('-q')
args = dict(
name=name,
ext_modules=[ext],
script_args=args,
)
with quiet():
setuptools.setup(**args)
subprocess.check_call([cc, src, f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-o", so])
return so


118 changes: 85 additions & 33 deletions python/test/unit/language/test_core.py
@@ -3284,42 +3284,94 @@ def convert_fp8_to_fp32(x, device, dtype_str):
assert "Unsupported float8 dtype"


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
def get_test_dot_base_cases():
return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for input_precision in ['tf32', 'tf32x3', 'ieee']
for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]
if not (input_precision != 'ieee' and (in_dtype in ['float16']))]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
def get_test_dot_mixed_sizes_cases():
available_kpack = [1, 2 if is_hip() else 1]
available_precision = ["tf32" if is_cuda() or is_xpu() else "ieee"]
return [
(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack, None)
for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4],
[32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]]
for input_precision in available_precision
for col_a in [True, False]
for col_b in [True, False]
for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16',
'float32'), ('float32', 'float32')]
for kpack in available_kpack
]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# introduced in #2370
def get_test_dot_transposed_op_base_cases():
return [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1, None)
for col_a in [True, False]
for col_b in [True, False]]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# Introduced in #2750
def get_test_dot_h100_shortcut_cases():
return [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1, None)]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# introduced in #3908
def get_test_dot_mfma_edge_cases():
if not is_hip_cdna():
return []
return [(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1, None),
(32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1, None)]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# introduced in #3370
def get_test_dot_fp8_output_cases():
return [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1, None)
for float8_type in ["float8e5", "float8e4nv"]]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# introduced in #5406
def get_test_dot_small_k_mfma_cases():
if not is_hip_cdna():
return []
return [(32, 32, k_size, 4, False, False, 'None', 'ieee', in_dtype, out_dtype, 1, mma_nonk_size)
for k_size in [1, 2, 4, 8]
for in_dtype, out_dtype in [('float16', 'float32'), ('int8', 'int32')]
for mma_nonk_size in mma_nonk_sizes]


# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
# introduced in #4516
def get_test_dot_small_mn_fma_cases():
return [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None)
for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4), (8, 16, 16, 1)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]


@pytest.mark.interpreter
@pytest.mark.parametrize(
"M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size",
[(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for input_precision in ['tf32', 'tf32x3', 'ieee']
for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]
if not (input_precision != 'ieee' and (in_dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack, None)
for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4],
[32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]]
for input_precision in ["tf32" if is_cuda() or is_xpu() else "ieee"]
for col_a in [True, False]
for col_b in [True, False]
for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32',
'float32')]
for kpack in [1, 2 if is_hip() else 1]] +
[(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1, None)
for col_a in [True, False]
for col_b in [True, False]] +
[(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1, None)] +
([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1, None),
(32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1, None)] if "gfx9" in get_arch() else []) +
[(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1, None)
for float8_type in ["float8e5", "float8e4nv"]] +
# small k cases for MFMA dots
[(32, 32, k_size, 4, False, False, 'None', 'ieee', in_dtype, out_dtype, 1, mma_nonk_size)
for k_size in [1, 2, 4, 8]
for in_dtype, out_dtype in [('float16', 'float32'), ('int8', 'int32')]
for mma_nonk_size in mma_nonk_sizes] +
# small m/n cases for FMA dots
[(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None)
for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4), (8, 16, 16, 1)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]])
get_test_dot_base_cases() + \
get_test_dot_mixed_sizes_cases() + \
get_test_dot_transposed_op_base_cases() + \
get_test_dot_h100_shortcut_cases() + \
get_test_dot_mfma_edge_cases() + \
get_test_dot_fp8_output_cases() + \
get_test_dot_small_k_mfma_cases() + \
get_test_dot_small_mn_fma_cases())
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size,
num_ctas, device):
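For reference, a minimal, self-contained sketch of the pattern used by the refactor above: case-generator helpers whose lists are concatenated into one `pytest.mark.parametrize` call (the helper names and shapes here are made up for illustration):

```python
import pytest


# Each helper returns a list of parameter tuples; a helper that does not apply
# on the current backend can simply return an empty list and contribute no cases.
def get_small_shape_cases():
    return [(16, 16, 16), (32, 32, 32)]


def get_large_shape_cases():
    return [(128, 128, 64)]


@pytest.mark.parametrize("M, N, K", get_small_shape_cases() + get_large_shape_cases())
def test_shapes(M, N, K):
    assert M > 0 and N > 0 and K > 0
```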
43 changes: 1 addition & 42 deletions python/triton/runtime/build.py
@@ -1,28 +1,14 @@
import contextlib
import sys
import io
import sysconfig
import os
import shutil
import subprocess
import setuptools


def is_xpu():
import torch
return torch.xpu.is_available()


@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr


def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
if cc in ["cl", "clang-cl"]:
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
@@ -103,32 +89,5 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
if os.getenv("VERBOSE"):
print(" ".join(cc_cmd))

ret = subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
if ret == 0:
return so
# extra arguments
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name=name,
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3' if os.name != "nt" else "/O2"],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
)
# build extension module
args = ['build_ext']
args.append('--build-temp=' + srcdir)
args.append('--build-lib=' + srcdir)
args.append('-q')
args = dict(
name=name,
ext_modules=[ext],
script_args=args,
)
with quiet():
setuptools.setup(**args)
subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
return so
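As background for the change above: `subprocess.check_call` raises `CalledProcessError` on a non-zero exit status and returns 0 otherwise, so the old `if ret == 0` guard always succeeded and the setuptools fallback was unreachable. A minimal sketch of that behavior, with a placeholder compiler command:

```python
import subprocess

cc_cmd = ["cc", "--version"]  # placeholder; the real build passes the full compile command
try:
    ret = subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
    assert ret == 0  # check_call only ever returns 0 on success
except subprocess.CalledProcessError as e:
    raise RuntimeError(f"compilation failed with exit code {e.returncode}") from e
```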
24 changes: 24 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -863,6 +863,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {

// -----

#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared0 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot_mmav3_shared
tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
// CHECK-NOT: nvgpu.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>

%D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0>

tt.return
}
}

// -----

#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>