Skip to content

Commit

Permalink
Use pynvjitlink for CUDA 12+ MVC (#13650)
Browse files Browse the repository at this point in the history
Fixes #12822

This PR provides minor version compatibility in the CUDA 12.x range through `nvjitlink` via the preliminary [nvjiitlink python binding](https://github.com/gmarkall/nvjitlink). Thus far this PR merely leverages a local installation of the library and should not be merged until `nvjitlink` is hosted on `conda-forge` and cuDF's dependencies are adjusted accordingly, likely as part of this PR.

Authors:
  - https://github.com/brandon-b-miller
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Ashwin Srinath (https://github.com/shwina)

URL: #13650
  • Loading branch information
brandon-b-miller authored Nov 20, 2023
1 parent 3ef13d0 commit 823d321
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 72 deletions.
99 changes: 99 additions & 0 deletions python/cudf/cudf/tests/test_mvc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
import subprocess
import sys

import pytest

IS_CUDA_11 = False
IS_CUDA_12 = False
try:
from ptxcompiler.patch import safe_get_versions
except ModuleNotFoundError:
from cudf.utils._ptxcompiler import safe_get_versions

# do not test cuda 12 if pynvjitlink isn't present
HAVE_PYNVJITLINK = False
try:
import pynvjitlink # noqa: F401

HAVE_PYNVJITLINK = True
except ModuleNotFoundError:
pass


versions = safe_get_versions()
driver_version, runtime_version = versions

if (11, 0) <= driver_version < (12, 0):
IS_CUDA_11 = True
if (12, 0) <= driver_version < (13, 0):
IS_CUDA_12 = True


TEST_BODY = """
@numba.cuda.jit
def test_kernel(x):
id = numba.cuda.grid(1)
if id < len(x):
x[id] += 1
s = cudf.Series([1, 2, 3])
with _CUDFNumbaConfig():
test_kernel.forall(len(s))(s)
"""

CUDA_11_TEST = (
"""
import numba.cuda
import cudf
from cudf.utils._numba import _CUDFNumbaConfig, patch_numba_linker_cuda_11
patch_numba_linker_cuda_11()
"""
+ TEST_BODY
)


CUDA_12_TEST = (
"""
import numba.cuda
import cudf
from cudf.utils._numba import _CUDFNumbaConfig
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)
patch_numba_linker_pynvjitlink()
"""
+ TEST_BODY
)


@pytest.mark.parametrize(
"test",
[
pytest.param(
CUDA_11_TEST,
marks=pytest.mark.skipif(
not IS_CUDA_11,
reason="Minor Version Compatibility test for CUDA 11",
),
),
pytest.param(
CUDA_12_TEST,
marks=pytest.mark.skipif(
not IS_CUDA_12 or not HAVE_PYNVJITLINK,
reason="Minor Version Compatibility test for CUDA 12",
),
),
],
)
def test_numba_mvc(test):
cp = subprocess.run(
[sys.executable, "-c", test],
capture_output=True,
cwd="/",
)

assert cp.returncode == 0
48 changes: 0 additions & 48 deletions python/cudf/cudf/tests/test_numba_import.py

This file was deleted.

53 changes: 29 additions & 24 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@

from numba import config as numba_config

try:
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)
except ImportError:

def patch_numba_linker_pynvjitlink():
warnings.warn(
"CUDA Toolkit is newer than CUDA driver. "
"Numba features will not work in this configuration. "
)


CC_60_PTX_FILE = os.path.join(
os.path.dirname(__file__), "../core/udf/shim_60.ptx"
)
Expand Down Expand Up @@ -65,7 +78,7 @@ def _get_ptx_file(path, prefix):
return regular_result[1]


def _patch_numba_mvc():
def patch_numba_linker_cuda_11():
# Enable the config option for minor version compatibility
numba_config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY = 1

Expand Down Expand Up @@ -106,29 +119,19 @@ def _setup_numba():
versions = safe_get_versions()
if versions != NO_DRIVER:
driver_version, runtime_version = versions
if driver_version >= (12, 0) and runtime_version > driver_version:
warnings.warn(
f"Using CUDA toolkit version {runtime_version} with CUDA "
f"driver version {driver_version} requires minor version "
"compatibility, which is not yet supported for CUDA "
"driver versions 12.0 and above. It is likely that many "
"cuDF operations will not work in this state. Please "
f"install CUDA toolkit version {driver_version} to "
"continue using cuDF."
)
else:
# Support MVC for all CUDA versions in the 11.x range
ptx_toolkit_version = _get_cuda_version_from_ptx_file(
CC_60_PTX_FILE
)
# Numba thinks cubinlinker is only needed if the driver is older
# than the CUDA runtime, but when PTX files are present, it might
# also need to patch because those PTX files may be compiled by
# a CUDA version that is newer than the driver as well
if (driver_version < ptx_toolkit_version) or (
driver_version < runtime_version
):
_patch_numba_mvc()
ptx_toolkit_version = _get_cuda_version_from_ptx_file(CC_60_PTX_FILE)

# MVC is required whenever any PTX is newer than the driver
# This could be the shipped PTX file or the PTX emitted by
# the version of NVVM on the user system, the latter aligning
# with the runtime version
if (driver_version < ptx_toolkit_version) or (
driver_version < runtime_version
):
if driver_version < (12, 0):
patch_numba_linker_cuda_11()
else:
patch_numba_linker_pynvjitlink()


def _get_cuda_version_from_ptx_file(path):
Expand Down Expand Up @@ -171,6 +174,8 @@ def _get_cuda_version_from_ptx_file(path):
"7.8": (11, 8),
"8.0": (12, 0),
"8.1": (12, 1),
"8.2": (12, 2),
"8.3": (12, 3),
}

cuda_ver = ver_map.get(version)
Expand Down

0 comments on commit 823d321

Please sign in to comment.