Skip to content

Commit

Permalink
Merge pull request #9 from SimeonEhrig/ruleNoNvccHost
Browse files Browse the repository at this point in the history
add filter rule, which forbids nvcc as host compiler
  • Loading branch information
SimeonEhrig authored Feb 8, 2024
2 parents 47e1bf6 + b130658 commit 4decd5a
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 39 deletions.
9 changes: 8 additions & 1 deletion bashi/filter_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Filter rules basing on backend names and versions."""
"""Filter rules basing on backend names and versions.
All rules implemented in this filter have an identifier that begins with "b" and follows a number.
Examples: b1, b42, b678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
"""

from typing import Optional, IO
from bashi.types import ParameterValueTuple
Expand Down
33 changes: 33 additions & 0 deletions bashi/filter_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Contains default filter chain and avoids circular import"""

from typeguard import typechecked
from bashi.types import FilterFunction

from bashi.filter_compiler_name import compiler_name_filter
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter


@typechecked
def get_default_filter_chain(
custom_filter_function: FilterFunction = lambda _: True,
) -> FilterFunction:
"""Concatenate the bashi filter functions in the default order and return them as one function
with a single entry point.
Args:
custom_filter_function (FilterFunction): This function is added as the last filter level and
allows the user to add custom filter rules without having to create the entire filter
chain from scratch. Defaults to lambda_:True.
Returns:
FilterFunction: The filter function chain, which can be directly used in bashi.FilterAdapter
"""
return (
lambda row: compiler_name_filter(row)
and compiler_version_filter(row)
and backend_filter(row)
and software_dependency_filter(row)
and custom_filter_function(row)
)
20 changes: 19 additions & 1 deletion bashi/filter_compiler_name.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
"""Filter rules basing on host and device compiler names."""
"""Filter rules basing on host and device compiler names.
All rules implemented in this filter have an identifier that begins with "n" and follows a number.
Examples: n1, n42, n678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
"""

from typing import Optional, IO
from bashi.types import ParameterValueTuple
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.utils import reason


def compiler_name_filter(
Expand All @@ -19,4 +28,13 @@ def compiler_name_filter(
Returns:
bool: True, if parameter-value-tuple is valid.
"""
# Rule: n1
# NVCC as HOST_COMPILER is not allow
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add NVCC as HOST_COMPILER and filter out afterwards
# this rule is only used by bashi-verify
if HOST_COMPILER in row and row[HOST_COMPILER].name == NVCC:
reason(output, "nvcc is not allowed as host compiler")
return False

return True
9 changes: 8 additions & 1 deletion bashi/filter_compiler_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Filter rules basing on host and device compiler names and versions."""
"""Filter rules basing on host and device compiler names and versions.
All rules implemented in this filter have an identifier that begins with "v" and follows a number.
Examples: v1, v42, v678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
"""

from typing import Optional, IO
from bashi.types import ParameterValueTuple
Expand Down
9 changes: 8 additions & 1 deletion bashi/filter_software_dependency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Filter rules handling software dependencies and compiler settings."""
"""Filter rules handling software dependencies and compiler settings.
All rules implemented in this filter have an identifier that begins with "sw" and follows a number.
Examples: sw1, sw42, sw678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
"""

from typing import Optional, IO
from bashi.types import ParameterValueTuple
Expand Down
2 changes: 1 addition & 1 deletion bashi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Combination,
CombinationList,
)
from bashi.utils import get_default_filter_chain
from bashi.filter_chain import get_default_filter_chain


def generate_combination_list(
Expand Down
47 changes: 18 additions & 29 deletions bashi/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Different helper functions for bashi"""

from typing import Dict, List, IO, Union
from typing import Dict, List, IO, Union, Optional
from collections import OrderedDict
import dataclasses
import sys
Expand All @@ -16,10 +16,6 @@
CombinationList,
FilterFunction,
)
from bashi.filter_compiler_name import compiler_name_filter
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter


@dataclasses.dataclass
Expand Down Expand Up @@ -82,30 +78,6 @@ def __call__(self, row: List[ParameterValue]) -> bool:
return self.filter_func(ordered_row)


@typechecked
def get_default_filter_chain(
custom_filter_function: FilterFunction = lambda _: True,
) -> FilterFunction:
"""Concatenate the bashi filter functions in the default order and return them as one function
with a single entry point.
Args:
custom_filter_function (FilterFunction): This function is added as the last filter level and
allows the user to add custom filter rules without having to create the entire filter
chain from scratch. Defaults to lambda_:True.
Returns:
FilterFunction: The filter function chain, which can be directly used in bashi.FilterAdapter
"""
return (
lambda row: compiler_name_filter(row)
and compiler_version_filter(row)
and backend_filter(row)
and software_dependency_filter(row)
and custom_filter_function(row)
)


@typechecked
def create_parameter_value_pair( # pylint: disable=too-many-arguments
parameter1: str,
Expand Down Expand Up @@ -321,3 +293,20 @@ def check_parameter_value_pair_in_combination_list(
missing_expected_param = True

return not missing_expected_param


def reason(output: Optional[IO[str]], msg: str):
"""Write the message to output if it is not None. This function is used
in filter functions to print additional information about filter decisions.
Args:
output (Optional[IO[str]]): IO object. For example, can be io.StringIO, sys.stdout or
sys.stderr
msg (str): the message
"""
if output:
print(
msg,
file=output,
end="",
)
5 changes: 5 additions & 0 deletions bashi/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def get_parameter_value_matrix() -> ParameterValueMatrix:
for compiler_type in [HOST_COMPILER, DEVICE_COMPILER]:
param_val_matrix[compiler_type] = []
for sw_name, sw_versions in extended_version.items():
# do not add NVCC as HOST_COMPILER
# filtering out all NVCC as HOST_COMPILER later does not work with the covertable
# library
if compiler_type == HOST_COMPILER and sw_name == NVCC:
continue
if sw_name in COMPILERS:
for sw_version in sw_versions:
param_val_matrix[compiler_type].append(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_filter_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from collections import OrderedDict
import packaging.version as pkv
from bashi.types import ParameterValue, ParameterValueTuple, FilterFunction
from bashi.utils import get_default_filter_chain, FilterAdapter
from bashi.utils import FilterAdapter
from bashi.filter_chain import get_default_filter_chain


class TestFilterChain(unittest.TestCase):
Expand Down
79 changes: 79 additions & 0 deletions tests/test_nvcc_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# pylint: disable=missing-docstring
import unittest
import io

from collections import OrderedDict as OD
from utils_test import parse_param_val as ppv
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.filter_compiler_name import compiler_name_filter


class TestNvccHostCompilerFilter(unittest.TestCase):
def test_valid_combination_rule_n1(self):
self.assertTrue(
compiler_name_filter(
OD({HOST_COMPILER: ppv((GCC, 10)), DEVICE_COMPILER: ppv((NVCC, 11.2))})
)
)

# version should not matter
self.assertTrue(
compiler_name_filter(
OD({HOST_COMPILER: ppv((CLANG, 0)), DEVICE_COMPILER: ppv((NVCC, 0))})
)
)

self.assertTrue(
compiler_name_filter(
OD(
{
HOST_COMPILER: ppv((CLANG, 0)),
DEVICE_COMPILER: ppv((NVCC, 0)),
CMAKE: ppv((CMAKE, "3.23")),
BOOST: ppv((BOOST, "1.81")),
}
)
)
)

# if HOST_COMPILER does not exist in the row, it should pass because HOST_COMPILER can be
# added at the next round
self.assertTrue(
compiler_name_filter(
OD(
{
DEVICE_COMPILER: ppv((NVCC, 0)),
CMAKE: ppv((CMAKE, "3.23")),
BOOST: ppv((BOOST, "1.81")),
}
)
)
)

self.assertTrue(compiler_name_filter(OD()))

def test_invalid_combination_rule_n1(self):
self.assertFalse(
compiler_name_filter(
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((NVCC, 11.2))})
)
)

self.assertFalse(
compiler_name_filter(
OD({HOST_COMPILER: ppv((NVCC, 11.2)), DEVICE_COMPILER: ppv((GCC, 11))})
)
)

self.assertFalse(
compiler_name_filter(
OD({HOST_COMPILER: ppv((NVCC, 12.2)), DEVICE_COMPILER: ppv((HIPCC, 5.1))})
)
)

self.assertFalse(compiler_name_filter(OD({HOST_COMPILER: ppv((NVCC, 10.2))})))

def test_reason_rule_n1(self):
reason_msg = io.StringIO()
self.assertFalse(compiler_name_filter(OD({HOST_COMPILER: ppv((NVCC, 10.2))}), reason_msg))
self.assertEqual(reason_msg.getvalue(), "nvcc is not allowed as host compiler")
12 changes: 8 additions & 4 deletions tests/test_params_value_matrix_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ def test_all_params_in(self):
def test_number_host_device_compiler(self):
extended_versions = VERSIONS.copy()
extended_versions[CLANG_CUDA] = extended_versions[CLANG]
number_of_compilers = 0
number_of_host_compilers = 0
for compiler in COMPILERS:
number_of_compilers += len(extended_versions[compiler])
if compiler != NVCC:
number_of_host_compilers += len(extended_versions[compiler])

self.assertEqual(len(self.param_val_matrix[HOST_COMPILER]), number_of_compilers)
self.assertEqual(len(self.param_val_matrix[DEVICE_COMPILER]), number_of_compilers)
# NVCC is only as device compiler added
number_of_device_compilers = number_of_host_compilers + len(extended_versions[NVCC])

self.assertEqual(len(self.param_val_matrix[HOST_COMPILER]), number_of_host_compilers)
self.assertEqual(len(self.param_val_matrix[DEVICE_COMPILER]), number_of_device_compilers)

def test_number_of_backends(self):
for backend in BACKENDS:
Expand Down

0 comments on commit 4decd5a

Please sign in to comment.