Skip to content

Commit

Permalink
Merge pull request #12 from SimeonEhrig/ruleNvccHostCompiler
Browse files Browse the repository at this point in the history
add filter rule which allows only gcc and clang as nvcc host compiler
  • Loading branch information
SimeonEhrig authored Feb 13, 2024
2 parents 20174e2 + 82859a4 commit 340d5d5
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 12 deletions.
10 changes: 10 additions & 0 deletions bashi/filter_compiler_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,14 @@ def compiler_name_filter(
reason(output, "nvcc is not allowed as host compiler")
return False

# Rule: n2
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == NVCC
and HOST_COMPILER in row
and not row[HOST_COMPILER].name in [GCC, CLANG]
):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False

return True
50 changes: 41 additions & 9 deletions bashi/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
"""Different helper functions for bashi"""

from typing import Dict, List, IO, Union, Optional
from collections import OrderedDict
import dataclasses
import sys
from typeguard import typechecked
from collections import OrderedDict
from typing import IO, Dict, List, Optional, Union

import packaging.version
from typeguard import typechecked

from bashi.types import (
CombinationList,
FilterFunction,
Parameter,
ParameterValue,
ParameterValueTuple,
ParameterValueSingle,
ParameterValuePair,
ParameterValueMatrix,
CombinationList,
FilterFunction,
ParameterValuePair,
ParameterValueSingle,
ParameterValueTuple,
)
from bashi.versions import COMPILERS
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import


@dataclasses.dataclass
Expand Down Expand Up @@ -126,7 +130,7 @@ def create_parameter_value_pair( # pylint: disable=too-many-arguments
def get_expected_parameter_value_pairs(
parameter_matrix: ParameterValueMatrix,
) -> List[ParameterValuePair]:
"""Takes parameter-value-matrix an creates a list of all expected parameter-values-pairs.
"""Takes parameter-value-matrix and creates a list of all expected parameter-values-pairs.
The pair-wise generator guaranties, that each pair of two parameter-values exist in at least one
combination if no filter rules exist. Therefore the generated the generated list can be used
to verify the output of the pair-wise generator.
Expand Down Expand Up @@ -310,3 +314,31 @@ def reason(output: Optional[IO[str]], msg: str):
file=output,
end="",
)


@typechecked
def get_expected_bashi_parameter_value_pairs(
parameter_matrix: ParameterValueMatrix,
) -> List[ParameterValuePair]:
"""Takes parameter-value-matrix and creates a list of all expected parameter-values-pairs
allowed by the bashi library. First it generates a complete list of parameter-value-pairs and
then it removes all pairs that are not allowed by filter rules.
Args:
parameter_matrix (ParameterValueMatrix): matrix of parameter values
Returns:
List[ParameterValuePair]: list of all parameter-value-pairs supported by bashi
"""
param_val_pair_list = get_expected_parameter_value_pairs(parameter_matrix)

for compiler_name in set(COMPILERS) - set([GCC, CLANG, NVCC]):
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER, compiler_name, 0, DEVICE_COMPILER, NVCC, 0
),
parameter_value_pairs=param_val_pair_list,
all_versions=True,
)

return param_val_pair_list
7 changes: 5 additions & 2 deletions tests/test_generate_combination_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bashi.generator import generate_combination_list
from bashi.utils import (
get_expected_parameter_value_pairs,
get_expected_bashi_parameter_value_pairs,
check_parameter_value_pair_in_combination_list,
remove_parameter_value_pair,
create_parameter_value_pair,
Expand Down Expand Up @@ -182,7 +183,7 @@ def custom_filter(row: ParameterValueTuple) -> bool:
class TestGeneratorRealData(unittest.TestCase):
def test_generator_without_custom_filter(self):
param_val_matrix = get_parameter_value_matrix()
expected_param_val_pairs = get_expected_parameter_value_pairs(param_val_matrix)
expected_param_val_pairs = get_expected_bashi_parameter_value_pairs(param_val_matrix)

comb_list = generate_combination_list(param_val_matrix)

Expand All @@ -203,7 +204,9 @@ def custom_filter(row: ParameterValueTuple) -> bool:
return True

param_val_matrix = get_parameter_value_matrix()
reduced_expected_param_val_pairs = get_expected_parameter_value_pairs(param_val_matrix)
reduced_expected_param_val_pairs = get_expected_bashi_parameter_value_pairs(
param_val_matrix
)

self.assertTrue(
remove_parameter_value_pair(
Expand Down
92 changes: 91 additions & 1 deletion tests/test_nvcc_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bashi.filter_compiler_name import compiler_name_filter_typechecked


class TestNvccHostCompilerFilter(unittest.TestCase):
class TestNoNvccHostCompiler(unittest.TestCase):
def test_valid_combination_rule_n1(self):
self.assertTrue(
compiler_name_filter_typechecked(
Expand Down Expand Up @@ -79,3 +79,93 @@ def test_reason_rule_n1(self):
compiler_name_filter_typechecked(OD({HOST_COMPILER: ppv((NVCC, 10.2))}), reason_msg)
)
self.assertEqual(reason_msg.getvalue(), "nvcc is not allowed as host compiler")


class TestSupportedNvccHostCompiler(unittest.TestCase):
def test_invalid_combination_rule_n2(self):
for compiler_name in [CLANG_CUDA, HIPCC, ICPX, NVCC]:
for compiler_version in ["0", "13", "32a2"]:
reason_msg = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((compiler_name, compiler_version)),
DEVICE_COMPILER: ppv((NVCC, "12.3")),
}
),
reason_msg,
)
)
# NVCC is filtered by rule n1
if compiler_name != NVCC:
self.assertEqual(
reason_msg.getvalue(),
"only gcc and clang are allowed as nvcc host compiler",
)

self.assertFalse(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((HIPCC, "5.3")),
DEVICE_COMPILER: ppv((NVCC, "12.3")),
CMAKE: ppv((CMAKE, "3.18")),
BOOST: ppv((BOOST, "1.81.0")),
}
)
)
)
self.assertFalse(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((HIPCC, "5.3")),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE, "1.0.0")
),
DEVICE_COMPILER: ppv((NVCC, "12.3")),
}
)
)
)

def test_valid_combination_rule_n2(self):
for compiler_name in [GCC, CLANG]:
for compiler_version in ["0", "13", "7b2"]:
self.assertTrue(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((compiler_name, compiler_version)),
DEVICE_COMPILER: ppv((NVCC, "12.3")),
}
)
)
)

self.assertTrue(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((GCC, "13")),
DEVICE_COMPILER: ppv((NVCC, "11.5")),
BOOST: ppv((BOOST, "1.84.0")),
CMAKE: ppv((CMAKE, "3.23")),
}
)
)
)
self.assertTrue(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, "14")),
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE, "1.0.0")
),
DEVICE_COMPILER: ppv((NVCC, "10.1")),
}
)
)
)

0 comments on commit 340d5d5

Please sign in to comment.