Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify compiler name and version filter #19

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions bashi/filter_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typeguard import typechecked
from bashi.types import FilterFunction

from bashi.filter_compiler_name import compiler_name_filter
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_compiler import compiler_filter
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter

Expand All @@ -25,8 +24,7 @@ def get_default_filter_chain(
FilterFunction: The filter function chain, which can be directly used in bashi.FilterAdapter
"""
return (
lambda row: compiler_name_filter(row)
and compiler_version_filter(row)
lambda row: compiler_filter(row)
and backend_filter(row)
and software_dependency_filter(row)
and custom_filter_function(row)
Expand Down
78 changes: 52 additions & 26 deletions bashi/filter_compiler_version.py → bashi/filter_compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Filter rules basing on host and device compiler names and versions.

All rules implemented in this filter have an identifier that begins with "v" and follows a number.
Examples: v1, v42, v678 ...
All rules implemented in this filter have an identifier that begins with "c" and follows a number.
Examples: c1, c42, c678 ...

These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
Expand All @@ -15,6 +15,9 @@
from bashi.versions import NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.utils import reason

# uncomment me for debugging
# from bashi.utils import print_row_nice


def get_required_parameters() -> List[Parameter]:
"""Return list of parameters which will be checked in the filter.
Expand All @@ -26,18 +29,19 @@ def get_required_parameters() -> List[Parameter]:


@typechecked
def compiler_version_filter_typechecked(
def compiler_filter_typechecked(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
"""Type-checked version of compiler_version_filter(). Type checking has a big performance cost,
which is why the non type-checked version is used for the pairwise generator.
"""Type-checked version of compiler_filter(). Type checking has a big performance cost, which
is why the non type-checked version is used for the pairwise generator.
"""
return compiler_version_filter(row, output)
return compiler_filter(row, output)


# pylint: disable=too-many-branches
def compiler_version_filter(
# pylint: disable=too-many-return-statements
def compiler_filter(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
Expand All @@ -52,22 +56,43 @@ def compiler_version_filter(
Returns:
bool: True, if parameter-value-tuple is valid.
"""
# uncomment me for debugging
# print_row_nice(row)

# Rule: v1
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].version != row[DEVICE_COMPILER].version
):
reason(output, "host and device compiler version must be the same (except for nvcc)")
# Rule: c1
# NVCC as HOST_COMPILER is not allow
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add NVCC as HOST_COMPILER and filter out afterwards
# this rule is only used by bashi-verify
if HOST_COMPILER in row and row[HOST_COMPILER].name == NVCC:
reason(output, "nvcc is not allowed as host compiler")
return False

if HOST_COMPILER in row and DEVICE_COMPILER in row:
if NVCC in (row[HOST_COMPILER].name, row[DEVICE_COMPILER].name):
# Rule: c2
if row[HOST_COMPILER].name not in (GCC, CLANG):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False
else:
# Rule: c3
if row[HOST_COMPILER].name != row[DEVICE_COMPILER].name:
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False

# Rule: c4
if row[HOST_COMPILER].version != row[DEVICE_COMPILER].version:
reason(
output,
"host and device compiler version must be the same (except for nvcc)",
)
return False

# now idea, how remove nested blocks without hitting the performance
# pylint: disable=too-many-nested-blocks
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
if HOST_COMPILER in row and row[HOST_COMPILER].name == GCC:
# Rule: v2
# Rule: c5
# remove all unsupported nvcc gcc version combinations
# define which is the latest supported gcc compiler for a nvcc version

Expand All @@ -87,7 +112,7 @@ def compiler_version_filter(
break

if HOST_COMPILER in row and row[HOST_COMPILER].name == CLANG:
# Rule: v4
# Rule: c7
if row[DEVICE_COMPILER].version >= pkv.parse("11.3") and row[
DEVICE_COMPILER
].version <= pkv.parse("11.5"):
Expand All @@ -97,7 +122,7 @@ def compiler_version_filter(
)
return False

# Rule: v3
# Rule: c6
# remove all unsupported nvcc clang version combinations
# define which is the latest supported clang compiler for a nvcc version

Expand All @@ -116,17 +141,18 @@ def compiler_version_filter(
return False
break

# Rule: v5
# Rule: c8
# clang-cuda 13 and older is not supported
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add the clang-cuda versions and filter it out afterwards
# this rule is only used by bashi-verify
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == CLANG_CUDA
and row[DEVICE_COMPILER].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False
for compiler in (HOST_COMPILER, DEVICE_COMPILER):
if (
compiler in row
and row[compiler].name == CLANG_CUDA
and row[compiler].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False

return True
81 changes: 0 additions & 81 deletions bashi/filter_compiler_name.py

This file was deleted.

38 changes: 2 additions & 36 deletions bashi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Dict, List
from collections import OrderedDict
import copy
import packaging.version as pkv

from covertable import make # type: ignore

Expand Down Expand Up @@ -34,44 +32,12 @@ def generate_combination_list(
Returns:
CombinationList: combination-list
"""
# use local version to do not modify parameter_value_matrix
local_param_val_mat = copy.deepcopy(parameter_value_matrix)

filter_chain = get_default_filter_chain(custom_filter)

def host_compiler_filter(param_val: ParameterValue) -> bool:
# Rule: n1
# remove nvcc as host compiler
if param_val.name == NVCC:
return False
# Rule: v5
# remove clang-cuda older than 14
if param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"):
return False

return True

def device_compiler_filter(param_val: ParameterValue) -> bool:
# Rule: v5
# remove clang-cuda older than 14
if param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"):
return False

return True

pre_filters = {HOST_COMPILER: host_compiler_filter, DEVICE_COMPILER: device_compiler_filter}

# some filter rules requires that specific parameter-values are already removed from the
# parameter-value-matrix
# otherwise the covertable library throws an error
for param, filter_func in pre_filters.items():
if param in local_param_val_mat:
local_param_val_mat[param] = list(filter(filter_func, local_param_val_mat[param]))

comb_list: CombinationList = []

all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=local_param_val_mat,
factors=parameter_value_matrix,
length=2,
pre_filter=filter_chain,
) # type: ignore
Expand All @@ -81,7 +47,7 @@ def device_compiler_filter(param_val: ParameterValue) -> bool:
tmp_comb: Combination = OrderedDict({})
# covertable does not keep the ordering of the parameters
# therefore we sort it
for param in local_param_val_mat.keys():
for param in parameter_value_matrix.keys():
tmp_comb[param] = all_pair[param]
comb_list.append(tmp_comb)

Expand Down
40 changes: 40 additions & 0 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,43 @@ def reason(output: Optional[IO[str]], msg: str):
file=output,
end="",
)


# do not cover code, because the function is only used for debugging
def print_row_nice(row: ParameterValueTuple, init: str = ""): # pragma: no cover
"""Prints a parameter-value-tuple in a short and nice way.

Args:
row (ParameterValueTuple): row with parameter-value-tuple
init (str, optional): Prefix of the output string. Defaults to "".
"""
s = init
short_name: dict[str, str] = {
HOST_COMPILER: "host",
DEVICE_COMPILER: "device",
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: "bOpenMP2thread",
ALPAKA_ACC_CPU_B_SEQ_T_OMP2_ENABLE: "bOpenMP2block",
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: "bSeq",
ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLE: "bThreads",
ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE: "bTBB",
ALPAKA_ACC_GPU_CUDA_ENABLE: "bCUDA",
ALPAKA_ACC_GPU_HIP_ENABLE: "bHIP",
ALPAKA_ACC_SYCL_ENABLE: "bSYCL",
CXX_STANDARD: "c++",
}
nice_version: dict[packaging.version.Version, str] = {
ON_VER: "ON",
OFF_VER: "OFF",
}

for param, val in row.items():
if param in [HOST_COMPILER, DEVICE_COMPILER]:
s += (
f"{short_name.get(param, param)}={short_name.get(val.name, val.name)}-"
f"{nice_version.get(val.version, str(val.version))} "
)
else:
s += (
f"{short_name.get(param, param)}={nice_version.get(val.version, str(val.version))} "
)
print(s)
17 changes: 4 additions & 13 deletions bashi/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ParameterValue, ParameterValueTuple
from bashi.versions import is_supported_version
import bashi.filter_compiler_name
import bashi.filter_compiler_version
import bashi.filter_compiler
import bashi.filter_backend
import bashi.filter_software_dependency

Expand Down Expand Up @@ -244,17 +243,9 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
all_true = 0
all_true += int(
check_single_filter(
bashi.filter_compiler_name.compiler_name_filter_typechecked,
bashi.filter_compiler.compiler_filter,
row,
bashi.filter_compiler_name.get_required_parameters(),
)
)

all_true += int(
check_single_filter(
bashi.filter_compiler_version.compiler_version_filter_typechecked,
row,
bashi.filter_compiler_version.get_required_parameters(),
bashi.filter_compiler.get_required_parameters(),
)
)
all_true += int(
Expand All @@ -273,7 +264,7 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
)

# each filter add a one, if it was successful
return all_true == 4
return all_true == 3


def main() -> None:
Expand Down
Loading
Loading