Skip to content

Commit

Permalink
add filter rule which disallow different versions for host and device…
Browse files Browse the repository at this point in the history
… compiler

- the only exception is, if the device compiler name is nvcc
- improved nvcc filter rule tests
  • Loading branch information
SimeonEhrig committed Feb 13, 2024
1 parent de4b0f7 commit 5211756
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 18 deletions.
2 changes: 1 addition & 1 deletion bashi/filter_compiler_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_required_parameters() -> List[Parameter]:
Returns:
List[Parameter]: list of checked parameters
"""
return [HOST_COMPILER]
return [HOST_COMPILER, DEVICE_COMPILER]


@typechecked
Expand Down
20 changes: 15 additions & 5 deletions bashi/filter_compiler_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from typing import Optional, IO, List
from typeguard import typechecked
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import Parameter, ParameterValueTuple
from bashi.utils import reason


def get_required_parameters() -> List[Parameter]:
Expand All @@ -18,7 +20,7 @@ def get_required_parameters() -> List[Parameter]:
Returns:
List[Parameter]: list of checked parameters
"""
return []
return [HOST_COMPILER, DEVICE_COMPILER]


@typechecked
Expand All @@ -32,11 +34,9 @@ def compiler_version_filter_typechecked(
return compiler_version_filter(row, output)


# TODO(SimeonEhrig): remove disable=unused-argument
# only required for the CI at the moment
def compiler_version_filter(
row: ParameterValueTuple, # pylint: disable=unused-argument
output: Optional[IO[str]] = None, # pylint: disable=unused-argument
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
"""Filter rules basing on host and device compiler names and versions.
Expand All @@ -50,4 +50,14 @@ def compiler_version_filter(
bool: True, if parameter-value-tuple is valid.
"""

# Rule: v1
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].version != row[DEVICE_COMPILER].version
):
reason(output, "host and device compiler version must be the same (except for nvcc)")
return False

return True
24 changes: 23 additions & 1 deletion bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ParameterValueSingle,
ParameterValueTuple,
)
from bashi.versions import COMPILERS
from bashi.versions import COMPILERS, VERSIONS
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import


Expand Down Expand Up @@ -332,6 +332,9 @@ def get_expected_bashi_parameter_value_pairs(
"""
param_val_pair_list = get_expected_parameter_value_pairs(parameter_matrix)

extend_versions = VERSIONS.copy()
extend_versions[CLANG_CUDA] = extend_versions[CLANG]

# remove all combinations where nvcc is device compiler and the host compiler is not gcc or
# clang
for compiler_name in set(COMPILERS) - set([GCC, CLANG, NVCC]):
Expand Down Expand Up @@ -361,4 +364,23 @@ def get_expected_bashi_parameter_value_pairs(
all_versions=True,
)

# remove all combinations, where host and device compiler version are different except the
# compiler name is nvcc
for compiler_name in set(COMPILERS) - set([NVCC]):
for compiler_version1 in extend_versions[compiler_name]:
for compiler_version2 in extend_versions[compiler_name]:
if compiler_version1 != compiler_version2:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
compiler_name,
compiler_version1,
DEVICE_COMPILER,
compiler_name,
compiler_version2,
),
parameter_value_pairs=param_val_pair_list,
all_versions=False,
)

return param_val_pair_list
121 changes: 121 additions & 0 deletions tests/test_filter_compiler_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# pylint: disable=missing-docstring
import unittest

import io
from collections import OrderedDict as OD
from utils_test import parse_param_val as ppv
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.filter_compiler_version import compiler_version_filter_typechecked


class TestEmptyRow(unittest.TestCase):
def test_empty_row_shall_always_pass(self):
self.assertTrue(compiler_version_filter_typechecked(OD()))


class TestHostDeviceCompilerSameVersion(unittest.TestCase):
def test_valid_combination_rule_v1(self):
for comb in [
(ppv((GCC, 10)), ppv((GCC, 10))),
(ppv((ICPX, "2040.1.0")), ppv((ICPX, "2040.1.0"))),
(ppv((HIPCC, 5.5)), ppv((HIPCC, 5.5))),
(ppv((CLANG, 13)), ppv((CLANG, 13))),
(ppv((CLANG_CUDA, 17)), ppv((CLANG_CUDA, 17))),
]:
self.assertTrue(
compiler_version_filter_typechecked(
OD({HOST_COMPILER: comb[0], DEVICE_COMPILER: comb[1]})
),
f"host compiler and device compiler version are not the same: {comb[0]} != {comb[1]}",
)

self.assertTrue(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, 10)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.2)),
DEVICE_COMPILER: ppv((CLANG_CUDA, 10)),
CMAKE: ppv((CMAKE, 3.18)),
}
)
),
)

self.assertTrue(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 14)),
DEVICE_COMPILER: ppv((CLANG, 14)),
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE, "1.0.0")
),
CMAKE: ppv((CMAKE, 3.24)),
BOOST: ppv((BOOST, 1.78)),
}
)
),
)

def test_invalid_combination_rule_v1(self):
for comb in [
(ppv((GCC, 10)), ppv((GCC, 11))),
(ppv((ICPX, "2023.1.0")), ppv((ICPX, "2040.1.0"))),
(ppv((HIPCC, 6)), ppv((HIPCC, 5.5))),
(ppv((CLANG, 0)), ppv((CLANG, 13))),
(ppv((CLANG_CUDA, "4a3")), ppv((CLANG_CUDA, 17))),
]:
reason_msg = io.StringIO()

self.assertFalse(
compiler_version_filter_typechecked(
OD({HOST_COMPILER: comb[0], DEVICE_COMPILER: comb[1]}), reason_msg
),
f"same host compiler and device compiler version should pass: {comb[0]} and {comb[1]}",
)
self.assertEqual(
reason_msg.getvalue(),
"host and device compiler version must be the same (except for nvcc)",
)

reason_msg_multi1 = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, 10)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.2)),
DEVICE_COMPILER: ppv((CLANG_CUDA, 11)),
CMAKE: ppv((CMAKE, 3.18)),
},
),
reason_msg_multi1,
),
)
self.assertEqual(
reason_msg_multi1.getvalue(),
"host and device compiler version must be the same (except for nvcc)",
)

reason_msg_multi2 = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((GCC, 15)),
DEVICE_COMPILER: ppv((GCC, 10)),
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE, "1.0.0")
),
CMAKE: ppv((CMAKE, 3.24)),
BOOST: ppv((BOOST, 1.78)),
}
),
reason_msg_multi2,
),
)
self.assertEqual(
reason_msg_multi2.getvalue(),
"host and device compiler version must be the same (except for nvcc)",
)
40 changes: 29 additions & 11 deletions tests/test_nvcc_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,37 @@ def test_valid_combination_rule_n1(self):
)

def test_invalid_combination_rule_n1(self):
reason_msg1 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((NVCC, 11.2))})
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((NVCC, 11.2))}),
reason_msg1,
)
)
self.assertEqual(reason_msg1.getvalue(), "nvcc is not allowed as host compiler")

reason_msg2 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((GCC, 11))})
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((GCC, 11))}), reason_msg2
)
)
self.assertEqual(reason_msg2.getvalue(), "nvcc is not allowed as host compiler")

reason_msg3 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD({HOST_COMPILER: ppv((NVCC, 12.2)), DEVICE_COMPILER: ppv((HIPCC, 5.1))})
OD({HOST_COMPILER: ppv((NVCC, 12.2)), DEVICE_COMPILER: ppv((HIPCC, 5.1))}),
reason_msg3,
)
)
self.assertEqual(reason_msg3.getvalue(), "nvcc is not allowed as host compiler")

self.assertFalse(compiler_name_filter_typechecked(OD({HOST_COMPILER: ppv((NVCC, 10.2))})))

def test_reason_rule_n1(self):
reason_msg = io.StringIO()
reason_msg4 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(OD({HOST_COMPILER: ppv((NVCC, 10.2))}), reason_msg)
compiler_name_filter_typechecked(OD({HOST_COMPILER: ppv((NVCC, 10.2))}), reason_msg4)
)
self.assertEqual(reason_msg.getvalue(), "nvcc is not allowed as host compiler")
self.assertEqual(reason_msg4.getvalue(), "nvcc is not allowed as host compiler")


class TestSupportedNvccHostCompiler(unittest.TestCase):
Expand All @@ -102,6 +107,7 @@ def test_invalid_combination_rule_n2(self):
"only gcc and clang are allowed as nvcc host compiler",
)

reason_msg1 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD(
Expand All @@ -111,9 +117,16 @@ def test_invalid_combination_rule_n2(self):
CMAKE: ppv((CMAKE, "3.18")),
BOOST: ppv((BOOST, "1.81.0")),
}
)
),
reason_msg1,
)
)
self.assertEqual(
reason_msg1.getvalue(),
"only gcc and clang are allowed as nvcc host compiler",
)

reason_msg2 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD(
Expand All @@ -124,9 +137,14 @@ def test_invalid_combination_rule_n2(self):
),
DEVICE_COMPILER: ppv((NVCC, "12.3")),
}
)
),
reason_msg2,
)
)
self.assertEqual(
reason_msg2.getvalue(),
"only gcc and clang are allowed as nvcc host compiler",
)

def test_valid_combination_rule_n2(self):
for compiler_name in [GCC, CLANG]:
Expand Down

0 comments on commit 5211756

Please sign in to comment.