Port and run runtime tests
Signed-off-by: Tsang, Whitney <[email protected]>
whitneywhtsang committed Jan 12, 2024
1 parent a47163d commit 4c3a0d1
Showing 7 changed files with 58 additions and 29 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/build_and_test.yml
@@ -57,11 +57,12 @@ jobs:
- name: Run core tests
if: ${{ env.BACKEND == 'XPU'}}
run: |
cd python/test/unit/language
python3 -m pytest --verbose --device xpu --ignore=test_line_info.py --ignore=test_subprocess.py
cd python/test/unit
python3 -m pytest -n 8 --verbose --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
# run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest --verbose --device xpu test_line_info.py
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest --verbose --device xpu language/test_line_info.py
- name: Run assert/print tests
if: ${{ env.BACKEND == 'XPU'}}
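A note on the split above: -n 8 runs the language suite in parallel under pytest-xdist, while runtime/ is deliberately run without -n because those tests mutate the shared on-disk compilation cache and race under parallel workers. A hypothetical alternative (not part of this commit) would be to isolate each xdist worker's cache through TRITON_CACHE_DIR, for example in a conftest.py:

    # Hypothetical sketch: give every pytest-xdist worker a private Triton
    # cache directory so the runtime tests could run in parallel without
    # racing on the shared cache.
    import os

    def pytest_configure(config):
        worker = os.environ.get("PYTEST_XDIST_WORKER")  # "gw0", "gw1", ... under -n
        if worker:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(".triton-cache", worker)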
8 changes: 5 additions & 3 deletions .github/workflows/build_and_test_2.yaml
@@ -150,10 +150,12 @@ jobs:
run: |
pip install pytest pytest-xdist
pip install torch==1.13.0a0+git6c9b55e intel_extension_for_pytorch==1.13.120+xpu -f https://developer.intel.com/ipex-whl-stable-xpu
cd python/test/unit/language
python3 -m pytest -n auto --verbose --device xpu --ignore=test_line_info.py --ignore=test_subprocess.py
cd python/test/unit
python3 -m pytest -n 8 --verbose --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
# run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -n auto test_line_info.py
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest --verbose --device xpu language/test_line_info.py
- name: Run assert/print tests
run: |
13 changes: 8 additions & 5 deletions python/test/unit/runtime/test_autotuner.py
@@ -4,11 +4,14 @@
import triton.language as tl
import pytest

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()
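The enable_sync_mode() call above works around the FIXME by (presumably) forcing submitted work to complete synchronously, since the Triton Level Zero queue and the IPEX SYCL queue cannot yet synchronize through events. A slightly more defensive variant, a sketch only and not part of this commit, would guard the call so the file still imports on non-XPU builds:

    # Hypothetical guard: only force synchronous mode when the XPU backend
    # and the IPEX helper actually exist.
    if hasattr(torch, "xpu") and hasattr(torch.xpu, "enable_sync_mode"):
        torch.xpu.enable_sync_mode()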


def test_kwargs():
N = 1024
src = torch.empty(N, device='cuda')
dst = torch.empty(N, device='cuda')
src = torch.empty(N, device='xpu')
dst = torch.empty(N, device='xpu')

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

@@ -26,7 +29,7 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):

def test_restore():
N = 1024
src = torch.zeros(N, device='cuda')
src = torch.zeros(N, device='xpu')

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

@@ -45,8 +48,8 @@ def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
@pytest.mark.parametrize('with_perf_model', [False, True])
def test_prune_configs(with_perf_model: bool):
N = 1024
src = torch.empty(N, device='cuda')
dst = torch.empty(N, device='cuda')
src = torch.empty(N, device='xpu')
dst = torch.empty(N, device='xpu')
records = {}

def early_config_prune(configs, named_args):
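For readers unfamiliar with the fixture these tests share, this is the pattern under test, reduced to a self-contained sketch rather than the tests' exact code (same assumptions as the tests themselves: an XPU build of Triton plus IPEX):

    import torch
    import triton
    import triton.language as tl

    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

    # The autotuner times each config on first launch and caches the winner per key.
    @triton.autotune(configs=configs, key=['N'])
    @triton.jit
    def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
        offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < N
        tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)

    N = 1024
    src = torch.empty(N, device='xpu')
    dst = torch.empty(N, device='xpu')
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
    _copy[grid](dst, src, N)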
27 changes: 16 additions & 11 deletions python/test/unit/runtime/test_cache.py
@@ -10,6 +10,9 @@
import triton.language as tl
from triton.runtime.jit import JITFunction

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()

tmpdir = ".tmp"


@@ -111,7 +114,7 @@ def inc_counter(*args, **kwargs):

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device='xpu')
for i in range(10):
kernel[(1, )](x, 1, BLOCK=1024)
assert counter == 1
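The inc_counter body is folded out of the diff; for context, JITFunction.cache_hook fires once per fresh compilation, so the assertion above verifies that ten launches of an unchanged kernel compile exactly once. Its presumed shape:

    counter = 0

    def inc_counter(*args, **kwargs):
        global counter
        counter += 1

    JITFunction.cache_hook = inc_counter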
@@ -127,7 +130,7 @@ def inc_counter(*args, **kwargs):

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device='xpu')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 4, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
@@ -141,9 +144,9 @@ def test_annotation():
def kernel(X, i: tl.int32):
tl.store(X, i)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device='xpu')

device = torch.cuda.current_device()
device = torch.xpu.current_device()
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
@@ -157,7 +160,7 @@ def test_constexpr_not_callable() -> None:
def kernel(X, c: tl.constexpr):
tl.store(X, 2)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device='xpu')
error = False
try:
kernel[(1, )](x, c="str")
@@ -180,12 +183,12 @@ def kernel_add(a, b, o, N: tl.constexpr):
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="xpu"),
torch.randn(32, dtype=torch.float32, device="xpu"),
torch.randn(32, dtype=torch.float32, device="xpu"),
32,
]
device = torch.cuda.current_device()
device = torch.xpu.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -203,7 +206,7 @@ def kernel_add(a, b, o, N: tl.constexpr):
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
device = torch.xpu.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -229,7 +232,7 @@ def test_jit_noinline() -> None:
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)

device = torch.cuda.current_device()
device = torch.xpu.current_device()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
@@ -257,3 +260,5 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

reset_tmp_dir()
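reset_tmp_dir(), called throughout this file and now also after the final test, is defined outside the visible hunks. Given tmpdir = ".tmp" above, its presumed shape points the compilation cache at a throwaway directory and clears it between tests:

    import os
    import shutil

    def reset_tmp_dir():
        os.environ["TRITON_CACHE_DIR"] = tmpdir
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir, ignore_errors=True)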
4 changes: 2 additions & 2 deletions python/test/unit/runtime/test_launch.py
@@ -29,8 +29,8 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):

tracemalloc.start()
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
inp = torch.randn(10, device='xpu')
out = torch.randn(10, device='xpu')
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
12 changes: 10 additions & 2 deletions python/test/unit/runtime/test_subproc.py
@@ -2,12 +2,16 @@
import os
import shutil

import pytest
import torch

import triton
import triton.language as tl
from triton.compiler import ASTSource

# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
torch.xpu.enable_sync_mode()

tmpdir = ".tmp"


@@ -30,10 +34,12 @@ def kernel_sub(a, b, o, N: tl.constexpr):
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
attrs=attrs,
)
triton.compile(src=src, target=("cuda", capability))
triton.compile(src=src, target=("xpu", capability))


def test_compile_in_subproc() -> None:
pytest.skip("FIXME: Port get_device_capability to XPU")

major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ())
@@ -55,10 +61,12 @@ def kernel_dot(Z):
tl.store(Z + offs, z)

src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict())
triton.compile(src=src, target=("cuda", capability))
triton.compile(src=src, target=("xpu", capability))


def test_compile_in_forked_subproc() -> None:
pytest.skip("FIXME: Port get_device_capability to XPU")

reset_tmp_dir()
major, minor = torch.cuda.get_device_capability(0)
capability = major * 10 + minor
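Both subprocess tests are skipped pending an XPU analogue of torch.cuda.get_device_capability. One possible shim is sketched below; note that XPU exposes no CUDA-style (major, minor) compute capability, and the property name used here is an assumption to verify against the IPEX documentation:

    import torch

    def xpu_target(device: int = 0):
        # Assumed IPEX API; substitute whatever integer the XPU backend
        # expects in place of CUDA's major * 10 + minor capability.
        props = torch.xpu.get_device_properties(device)
        return ("xpu", getattr(props, "version", 0))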
14 changes: 12 additions & 2 deletions scripts/test-triton.sh
@@ -96,18 +96,22 @@ function run_core_tests {
echo "***************************************************"
echo "****** Running Triton Core tests ******"
echo "***************************************************"
CORE_TEST_DIR=$TRITON_PROJ/python/test/unit/language
CORE_TEST_DIR=$TRITON_PROJ/python/test/unit
if [ ! -d "${CORE_TEST_DIR}" ]; then
echo "Not found '${CORE_TEST_DIR}'. Build Triton please" ; exit 3
fi
cd $CORE_TEST_DIR

cd $CORE_TEST_DIR/language
TRITON_DISABLE_LINE_INFO=1 python3 -m pytest --verbose --device xpu --ignore=test_line_info.py --ignore=test_subprocess.py
rc=$?  # capture the test status before [ and echo overwrite $?
if [ $rc -ne 0 ]; then
echo "FAILED: return code $rc" ; exit $rc
fi

# run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest --verbose --device xpu test_line_info.py
rc=$?
if [ $rc -ne 0 ]; then
echo "FAILED: return code $rc" ; exit $rc
fi

python3 assert_helper.py device_assert
rc=$?
if [ $rc -ne 0 ]; then
@@ -117,6 +121,12 @@
rc=$?
if [ $rc -ne 0 ]; then
echo "FAILED: return code $rc" ; exit $rc
fi

cd $CORE_TEST_DIR/runtime
TRITON_DISABLE_LINE_INFO=1 python3 -m pytest --verbose
rc=$?
if [ $rc -ne 0 ]; then
echo "FAILED: return code $rc" ; exit $rc
fi
}

function run_tutorial_test {
