Skip to content

Commit

Permalink
add filter rule that nvcc version only supports up to a specific clan…
Browse files Browse the repository at this point in the history
…g version
  • Loading branch information
SimeonEhrig committed Feb 14, 2024
1 parent 615ae80 commit 8f5fe2d
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 15 deletions.
28 changes: 24 additions & 4 deletions bashi/filter_compiler_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typeguard import typechecked
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import Parameter, ParameterValueTuple
from bashi.versions import NVCC_GCC_MAX_VERSION
from bashi.versions import NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.utils import reason


Expand Down Expand Up @@ -73,9 +73,9 @@ def compiler_version_filter(
# latest gcc compiler version
if row[DEVICE_COMPILER].version <= NVCC_GCC_MAX_VERSION[0].nvcc:
# check the maximum supported gcc version for the given nvcc version
for comb in NVCC_GCC_MAX_VERSION:
if row[DEVICE_COMPILER].version >= comb.nvcc:
if row[HOST_COMPILER].version > comb.host:
for nvcc_gcc_comb in NVCC_GCC_MAX_VERSION:
if row[DEVICE_COMPILER].version >= nvcc_gcc_comb.nvcc:
if row[HOST_COMPILER].version > nvcc_gcc_comb.host:
reason(
output,
f"nvcc {row[DEVICE_COMPILER].version} "
Expand All @@ -84,4 +84,24 @@ def compiler_version_filter(
return False
break

if HOST_COMPILER in row and row[HOST_COMPILER].name == CLANG:
# Rule: v3
# remove all unsupported nvcc clang version combinations
# define which is the latest supported clang compiler for a nvcc version

# if a nvcc version is not supported by bashi, assume that the version supports the
# latest clang compiler version
if row[DEVICE_COMPILER].version <= NVCC_CLANG_MAX_VERSION[0].nvcc:
# check the maximum supported gcc version for the given nvcc version
for nvcc_clang_comb in NVCC_CLANG_MAX_VERSION:
if row[DEVICE_COMPILER].version >= nvcc_clang_comb.nvcc:
if row[HOST_COMPILER].version > nvcc_clang_comb.host:
reason(
output,
f"nvcc {row[DEVICE_COMPILER].version} "
f"does not support clang {row[HOST_COMPILER].version}",
)
return False
break

return True
38 changes: 32 additions & 6 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ParameterValueSingle,
ParameterValueTuple,
)
from bashi.versions import COMPILERS, VERSIONS, NVCC_GCC_MAX_VERSION
from bashi.versions import COMPILERS, VERSIONS, NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import


Expand Down Expand Up @@ -390,13 +390,39 @@ def get_expected_bashi_parameter_value_pairs(
gcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[GCC]]
gcc_versions.sort()
for nvcc_version in nvcc_versions:
for max_nvcc_gcc_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_gcc_version.nvcc:
for gcc_version in gcc_versions:
if gcc_version > max_nvcc_gcc_version.host:
for max_nvcc_clang_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_clang_version.nvcc:
for clang_version in gcc_versions:
if clang_version > max_nvcc_clang_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER, GCC, gcc_version, DEVICE_COMPILER, NVCC, nvcc_version
HOST_COMPILER,
GCC,
clang_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
),
parameter_value_pairs=param_val_pair_list,
)
break

clang_versions = [packaging.version.parse(str(v)) for v in VERSIONS[CLANG]]
clang_versions.sort()

for nvcc_version in nvcc_versions:
for max_nvcc_clang_version in NVCC_CLANG_MAX_VERSION:
if nvcc_version >= max_nvcc_clang_version.nvcc:
for clang_version in clang_versions:
if clang_version > max_nvcc_clang_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
CLANG,
clang_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
),
parameter_value_pairs=param_val_pair_list,
)
Expand Down
22 changes: 22 additions & 0 deletions bashi/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,28 @@ def __str__(self) -> str:
]
NVCC_GCC_MAX_VERSION.sort(reverse=True)

# define the maximum supported clang version for a specific nvcc version
# the latest supported nvcc version must be added, even if the supported clang version does not
# increase
# e.g.:
# NvccHostSupport("12.3", "16"),
# NvccHostSupport("12.2", "15"),
# NvccHostSupport("12.1", "15"),
NVCC_CLANG_MAX_VERSION: List[NvccHostSupport] = [
NvccHostSupport("12.3", "16"),
NvccHostSupport("12.2", "15"),
NvccHostSupport("12.1", "15"),
NvccHostSupport("12.0", "14"),
NvccHostSupport("11.6", "13"),
NvccHostSupport("11.4", "12"),
NvccHostSupport("11.2", "11"),
NvccHostSupport("11.1", "10"),
NvccHostSupport("11.0", "9"),
NvccHostSupport("10.1", "8"),
NvccHostSupport("10.0", "6"),
]
NVCC_CLANG_MAX_VERSION.sort(reverse=True)


def get_parameter_value_matrix() -> ParameterValueMatrix:
"""Generates a parameter-value-matrix from all supported compilers, softwares and compilation
Expand Down
12 changes: 9 additions & 3 deletions example/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List
import os
import sys
from bashi.generator import generate_combination_list
from bashi.utils import (
get_expected_bashi_parameter_value_pairs,
Expand Down Expand Up @@ -225,8 +226,13 @@ def create_yaml(combination_list: CombinationList):
parameter_value_matrix=param_matrix, custom_filter=custom_filter
)

print("verify combination-list")
verify(comb_list, param_matrix)

create_yaml(comb_list)
print(f"number of combinations: {len(comb_list)}")

print("verify combination-list")
if verify(comb_list, param_matrix):
print("verification passed")
sys.exit(0)

print("verification failed")
sys.exit(1)
170 changes: 168 additions & 2 deletions tests/test_nvcc_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def test_valid_multi_row_entries_gcc_rule_v2(self):
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 12.1)),
DEVICE_COMPILER: ppv((NVCC, 12.1)),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, "1.0.0")
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, ON)
),
}
)
Expand Down Expand Up @@ -470,7 +470,7 @@ def test_invalid_multi_row_entries_gcc_rule_v2(self):
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.8)),
DEVICE_COMPILER: ppv((NVCC, 11.8)),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, "1.0.0")
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, ON)
),
}
),
Expand Down Expand Up @@ -504,3 +504,169 @@ def test_unknown_combination_gcc_rule_v2(self):
f"nvcc {unsupported_nvcc_version} should pass the filter, because it is unknown "
"version",
)


class TestNvccSupportedClangVersion(unittest.TestCase):
def test_valid_combination_max_clang_rule_v3(self):
# change the version, if you added a new cuda release
# this test is a guard to be sure, that the following test contains the latest nvcc release
latest_covered_nvcc_release = "12.3"
self.assertEqual(
latest_covered_nvcc_release,
str(VERSIONS[NVCC][-1]),
f"The tests cases covers up to nvcc version {latest_covered_nvcc_release}.\n"
f"VERSION[NVCC] defines nvcc {VERSIONS[NVCC][-1]} as latest supported version.",
)

# add the latest supported clang version for a supported nvcc version and also the successor
# clang version
expected_results = [
("10.0", "6", True),
("10.0", "7", False),
("10.1", "8", True),
("10.1", "9", False),
("10.2", "8", True),
("10.2", "9", False),
("11.0", "9", True),
("11.0", "10", False),
("11.1", "10", True),
("11.1", "11", False),
("11.2", "11", True),
("11.2", "12", False),
# because of compiler bugs, clang is disabled for CUDA 11.3 until 11.5
# ("11.3", "11", False),
# ("11.3", "12", False),
# ("11.4", "12", False),
# ("11.4", "13", False),
# ("11.5", "12", False),
# ("11.5", "13", False),
("11.6", "13", True),
("11.6", "14", False),
("11.7", "13", True),
("11.7", "14", False),
("11.8", "13", True),
("11.8", "14", False),
("12.0", "14", True),
("12.0", "15", False),
("12.1", "15", True),
("12.1", "16", False),
("12.2", "15", True),
("12.2", "16", False),
("12.3", "16", True),
("12.3", "17", False),
]

for nvcc_version, clang_version, expected_filter_return_value in expected_results:
reason_msg = io.StringIO()
self.assertEqual(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, clang_version)),
DEVICE_COMPILER: ppv((NVCC, nvcc_version)),
}
),
reason_msg,
),
expected_filter_return_value,
f"the filter for the combination of nvcc {nvcc_version} + clang {clang_version} "
f"should return {expected_filter_return_value}",
)
if not expected_filter_return_value:
self.assertEqual(
reason_msg.getvalue(),
f"nvcc {nvcc_version} " f"does not support clang {clang_version}",
)

def test_valid_multi_row_entries_clang_rule_v2(self):
self.assertTrue(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 10)),
DEVICE_COMPILER: ppv((NVCC, 11.2)),
CMAKE: ppv((CMAKE, 3.18)),
BOOST: ppv((BOOST, 1.78)),
}
)
)
)

self.assertTrue(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 12)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 12.1)),
DEVICE_COMPILER: ppv((NVCC, 12.1)),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, ON)
),
}
)
)
)

def test_invalid_multi_row_entries_clang_rule_v2(self):
reason_msg1 = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 13)),
DEVICE_COMPILER: ppv((NVCC, 11.2)),
CMAKE: ppv((CMAKE, 3.18)),
BOOST: ppv((BOOST, 1.78)),
}
),
reason_msg1,
),
)
self.assertEqual(
reason_msg1.getvalue(),
"nvcc 11.2 does not support clang 13",
)

reason_msg2 = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 16)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.8)),
DEVICE_COMPILER: ppv((NVCC, 11.8)),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE, ON)
),
}
),
reason_msg2,
)
)
self.assertEqual(
reason_msg2.getvalue(),
"nvcc 11.8 does not support clang 16",
)

def test_unknown_combination_clang_rule_v2(self):
# test an unsupported nvcc version
# we assume, that the nvcc supports all gcc versions
unsupported_nvcc_version = 42.0
self.assertFalse(
unsupported_nvcc_version in VERSIONS[NVCC],
f"for the test, it is required that nvcc {unsupported_nvcc_version} is an unsupported "
"version",
)

self.assertTrue(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 12)),
DEVICE_COMPILER: ppv((NVCC, unsupported_nvcc_version)),
}
),
),
f"nvcc {unsupported_nvcc_version} should pass the filter, because it is unknown "
"version",
)

0 comments on commit 8f5fe2d

Please sign in to comment.