Skip to content

Commit

Permalink
Merge pull request #13 from SimeonEhrig/RuleSameCompilerName
Browse files Browse the repository at this point in the history
add filter rule which disallow different names for host and device compiler
  • Loading branch information
SimeonEhrig authored Feb 13, 2024
2 parents 340d5d5 + b52ed76 commit de4b0f7
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 5 deletions.
10 changes: 10 additions & 0 deletions bashi/filter_compiler_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,14 @@ def compiler_name_filter(
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False

# Rule: n3
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].name != row[DEVICE_COMPILER].name
):
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False

return True
20 changes: 20 additions & 0 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ def get_expected_bashi_parameter_value_pairs(
"""
param_val_pair_list = get_expected_parameter_value_pairs(parameter_matrix)

# remove all combinations where nvcc is device compiler and the host compiler is not gcc or
# clang
for compiler_name in set(COMPILERS) - set([GCC, CLANG, NVCC]):
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
Expand All @@ -341,4 +343,22 @@ def get_expected_bashi_parameter_value_pairs(
all_versions=True,
)

# remove all combinations, where host and device compiler name are different except the device
# compiler name is nvcc
for host_compiler_name in set(COMPILERS) - set([NVCC]):
for device_compiler_name in set(COMPILERS) - set([NVCC]):
if host_compiler_name != device_compiler_name:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
host_compiler_name,
0,
DEVICE_COMPILER,
device_compiler_name,
0,
),
parameter_value_pairs=param_val_pair_list,
all_versions=True,
)

return param_val_pair_list
121 changes: 121 additions & 0 deletions tests/test_filter_compiler_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# pylint: disable=missing-docstring
import unittest

import io
from collections import OrderedDict as OD
from utils_test import parse_param_val as ppv
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.filter_compiler_name import compiler_name_filter_typechecked


class TestEmptyRow(unittest.TestCase):
def test_empty_row_shall_always_pass(self):
self.assertTrue(compiler_name_filter_typechecked(OD()))


class TestHostDeviceCompilerSameName(unittest.TestCase):
def test_valid_combination_rule_n3(self):
for comb in [
(ppv((GCC, 10)), ppv((GCC, 10))),
(ppv((GCC, 1)), ppv((GCC, 10))), # version is not important
(ppv((HIPCC, 6.0)), ppv((HIPCC, 6.0))),
(ppv((ICPX, 3.0)), ppv((ICPX, 6.0))),
(ppv((CLANG, 3.0)), ppv((NVCC, 6.0))), # nvcc has device compiler is an exception
(ppv((GCC, 3.0)), ppv((NVCC, 6.0))),
]:
self.assertTrue(
compiler_name_filter_typechecked(
OD({HOST_COMPILER: comb[0], DEVICE_COMPILER: comb[1]})
),
f"host compiler and device compiler name are not the same: {comb[0]} != {comb[1]}",
)

self.assertTrue(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, 10)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.2)),
DEVICE_COMPILER: ppv((CLANG_CUDA, 10)),
CMAKE: ppv((CMAKE, 3.18)),
}
)
),
)

self.assertTrue(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG, 10)),
DEVICE_COMPILER: ppv((CLANG, 10)),
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE, "1.0.0")
),
CMAKE: ppv((CMAKE, 3.24)),
BOOST: ppv((BOOST, 1.78)),
}
)
),
)

def test_invalid_combination_rule_n3(self):
for comb in [
(ppv((GCC, 10)), ppv((CLANG, 10))),
(ppv((HIPCC, 1)), ppv((GCC, 10))), # version is not important
(ppv((HIPCC, 6.0)), ppv((ICPX, 6.0))),
(ppv((ICPX, 3.0)), ppv((CLANG_CUDA, 6.0))),
]:
reason_msg = io.StringIO()

self.assertFalse(
compiler_name_filter_typechecked(
OD({HOST_COMPILER: comb[0], DEVICE_COMPILER: comb[1]}), reason_msg
),
f"same host compiler and device compiler name should pass: {comb[0]} and {comb[1]}",
)
self.assertEqual(
reason_msg.getvalue(),
"host and device compiler name must be the same (except for nvcc)",
)

reason_msg_multi1 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, 10)),
ALPAKA_ACC_GPU_CUDA_ENABLE: ppv((ALPAKA_ACC_GPU_CUDA_ENABLE, 11.2)),
DEVICE_COMPILER: ppv((HIPCC, 10)),
CMAKE: ppv((CMAKE, 3.18)),
},
),
reason_msg_multi1,
),
)
self.assertEqual(
reason_msg_multi1.getvalue(),
"host and device compiler name must be the same (except for nvcc)",
)

reason_msg_multi2 = io.StringIO()
self.assertFalse(
compiler_name_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((GCC, 15)),
DEVICE_COMPILER: ppv((CLANG, 10)),
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: ppv(
(ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE, "1.0.0")
),
CMAKE: ppv((CMAKE, 3.24)),
BOOST: ppv((BOOST, 1.78)),
}
),
reason_msg_multi2,
),
)
self.assertEqual(
reason_msg_multi2.getvalue(),
"host and device compiler name must be the same (except for nvcc)",
)
13 changes: 10 additions & 3 deletions tests/test_generate_combination_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from bashi.versions import get_parameter_value_matrix
from bashi.generator import generate_combination_list
from bashi.utils import (
get_expected_parameter_value_pairs,
get_expected_bashi_parameter_value_pairs,
check_parameter_value_pair_in_combination_list,
remove_parameter_value_pair,
Expand All @@ -33,13 +32,21 @@ def setUpClass(cls):
[(GCC, 10), (GCC, 11), (GCC, 12), (CLANG, 16), (CLANG, 17)]
)
cls.param_matrix[DEVICE_COMPILER] = parse_param_vals(
[(NVCC, 11.2), (NVCC, 12.0), (GCC, 10), (GCC, 11)]
[
(NVCC, 11.2),
(NVCC, 12.0),
(GCC, 10),
(GCC, 11),
(GCC, 12),
(CLANG, 16),
(CLANG, 17),
]
)
cls.param_matrix[CMAKE] = parse_param_vals([(CMAKE, 3.22), (CMAKE, 3.23)])
cls.param_matrix[BOOST] = parse_param_vals([(BOOST, 1.81), (BOOST, 1.82), (BOOST, 1.83)])

cls.generated_parameter_value_pairs: List[ParameterValuePair] = (
get_expected_parameter_value_pairs(cls.param_matrix)
get_expected_bashi_parameter_value_pairs(cls.param_matrix)
)

def test_generator_without_custom_filter(self):
Expand Down
2 changes: 0 additions & 2 deletions tests/test_nvcc_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def test_valid_combination_rule_n1(self):
)
)

self.assertTrue(compiler_name_filter_typechecked(OD()))

def test_invalid_combination_rule_n1(self):
self.assertFalse(
compiler_name_filter_typechecked(
Expand Down

0 comments on commit de4b0f7

Please sign in to comment.