Skip to content

Commit

Permalink
generate_combination_list() now filters the parameter-value-matrix
Browse files Browse the repository at this point in the history
- it is no longer necessary to avoid specific parameter-values, such as nvcc as host-compiler
- also filters clang-cuda 13 and older
- replace .copy() with deepcopy() to solve some memory problems
  • Loading branch information
SimeonEhrig committed Feb 15, 2024
1 parent d0b22df commit 5d025ad
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 37 deletions.
42 changes: 36 additions & 6 deletions bashi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Dict, List
from collections import OrderedDict
import copy
import packaging.version as pkv

from covertable import make # type: ignore

Expand All @@ -13,6 +15,7 @@
Combination,
CombinationList,
)
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.filter_chain import get_default_filter_chain


Expand All @@ -31,17 +34,44 @@ def generate_combination_list(
Returns:
CombinationList: combination-list
"""
# work on a local copy so that parameter_value_matrix is not modified
local_param_val_mat = copy.deepcopy(parameter_value_matrix)

filter_chain = get_default_filter_chain(custom_filter)

# TODO(SimeonEhrig): add filter function here, which remove NVCC as host compiler and
# CLANG-CUDA 13 and older as compiler
# the covertable throws an error, if the filter rule removes to much possibilities in an early
# filter
def host_compiler_filter(param_val: ParameterValue) -> bool:
    """Decide whether a parameter-value stays in the host-compiler list.

    Args:
        param_val (ParameterValue): parameter-value of the host-compiler list

    Returns:
        bool: False if the name is NVCC (rule n1) or if the name is CLANG_CUDA and the
        version is lower than 14 (rule v5), otherwise True
    """
    rejected = (
        # Rule: n1
        # remove nvcc as host compiler
        param_val.name == NVCC
        # Rule: v5
        # remove clang-cuda older than 14
        or (param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"))
    )
    return not rejected

def device_compiler_filter(param_val: ParameterValue) -> bool:
    """Decide whether a parameter-value stays in the device-compiler list.

    Args:
        param_val (ParameterValue): parameter-value of the device-compiler list

    Returns:
        bool: False if the name is CLANG_CUDA and the version is lower than 14 (rule v5),
        otherwise True
    """
    # Rule: v5
    # remove clang-cuda older than 14
    is_old_clang_cuda = param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14")
    return not is_old_clang_cuda

pre_filters = {HOST_COMPILER: host_compiler_filter, DEVICE_COMPILER: device_compiler_filter}

# some filter rules require that specific parameter-values are already removed from the
# parameter-value-matrix
# otherwise the covertable library throws an error
for param, filter_func in pre_filters.items():
if param in local_param_val_mat:
local_param_val_mat[param] = list(filter(filter_func, local_param_val_mat[param]))

comb_list: CombinationList = []

all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=parameter_value_matrix,
factors=local_param_val_mat,
length=2,
pre_filter=filter_chain,
) # type: ignore
Expand All @@ -51,7 +81,7 @@ def generate_combination_list(
tmp_comb: Combination = OrderedDict({})
# covertable does not keep the ordering of the parameters
# therefore we sort it
for param in parameter_value_matrix.keys():
for param in local_param_val_mat.keys():
tmp_comb[param] = all_pair[param]
comb_list.append(tmp_comb)

Expand Down
42 changes: 35 additions & 7 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import sys
import copy
from collections import OrderedDict
from typing import IO, Dict, List, Optional, Union

Expand Down Expand Up @@ -316,7 +317,9 @@ def reason(output: Optional[IO[str]], msg: str):
)


# TODO(SimeonEhrig) modularize the function
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
@typechecked
def get_expected_bashi_parameter_value_pairs(
parameter_matrix: ParameterValueMatrix,
Expand All @@ -331,9 +334,34 @@ def get_expected_bashi_parameter_value_pairs(
Returns:
List[ParameterValuePair]: list of all parameter-value-pairs supported by bashi
"""
param_val_pair_list = get_expected_parameter_value_pairs(parameter_matrix)
local_parameter_matrix = copy.deepcopy(parameter_matrix)

extend_versions = VERSIONS.copy()
def remove_host_compiler_nvcc(param_val: ParameterValue) -> bool:
    """Filter predicate which rejects every parameter-value named NVCC.

    Args:
        param_val (ParameterValue): parameter-value to check

    Returns:
        bool: False if the name of the parameter-value is NVCC, otherwise True
    """
    return param_val.name != NVCC

# remove nvcc as host compiler
local_parameter_matrix[HOST_COMPILER] = list(
filter(remove_host_compiler_nvcc, local_parameter_matrix[HOST_COMPILER])
)

# remove clang-cuda 13 and older
def remove_unsupported_clang_cuda_version(param_val: ParameterValue) -> bool:
    """Filter predicate which rejects clang-cuda parameter-values older than version 14.

    Args:
        param_val (ParameterValue): parameter-value to check

    Returns:
        bool: False if the name is CLANG_CUDA and the version is lower than 14,
        otherwise True
    """
    is_unsupported = (
        param_val.name == CLANG_CUDA and param_val.version < packaging.version.parse("14")
    )
    return not is_unsupported

local_parameter_matrix[HOST_COMPILER] = list(
filter(remove_unsupported_clang_cuda_version, local_parameter_matrix[HOST_COMPILER])
)
local_parameter_matrix[DEVICE_COMPILER] = list(
filter(remove_unsupported_clang_cuda_version, local_parameter_matrix[DEVICE_COMPILER])
)

param_val_pair_list = get_expected_parameter_value_pairs(local_parameter_matrix)

extend_versions = copy.deepcopy(VERSIONS)
extend_versions[CLANG_CUDA] = extend_versions[CLANG]

# remove all combinations where nvcc is device compiler and the host compiler is not gcc or
Expand Down Expand Up @@ -390,15 +418,15 @@ def get_expected_bashi_parameter_value_pairs(
gcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[GCC]]
gcc_versions.sort()
for nvcc_version in nvcc_versions:
for max_nvcc_clang_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_clang_version.nvcc:
for clang_version in gcc_versions:
if clang_version > max_nvcc_clang_version.host:
for max_nvcc_gcc_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_gcc_version.nvcc:
for gcc_version in gcc_versions:
if gcc_version > max_nvcc_gcc_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
GCC,
clang_version,
gcc_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
Expand Down
17 changes: 6 additions & 11 deletions bashi/versions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Provides all supported software versions"""

import copy
from typing import Dict, List, Union
from collections import OrderedDict
from typeguard import typechecked
Expand Down Expand Up @@ -121,23 +122,17 @@ def get_parameter_value_matrix() -> ParameterValueMatrix:
"""
param_val_matrix: ParameterValueMatrix = OrderedDict()

extended_version = VERSIONS.copy()
extended_version = copy.deepcopy(VERSIONS)
extended_version[CLANG_CUDA] = extended_version[CLANG]

for compiler_type in [HOST_COMPILER, DEVICE_COMPILER]:
param_val_matrix[compiler_type] = []
for sw_name, sw_versions in extended_version.items():
# do not add NVCC as HOST_COMPILER
# filtering out all NVCC as HOST_COMPILER later does not work with the covertable
# library
if compiler_type == HOST_COMPILER and sw_name == NVCC:
continue
if sw_name in COMPILERS:
for sw_version in sw_versions:
if not (sw_name == CLANG_CUDA and pkv.parse(str(sw_version)) < pkv.parse("14")):
param_val_matrix[compiler_type].append(
ParameterValue(sw_name, pkv.parse(str(sw_version)))
)
param_val_matrix[compiler_type].append(
ParameterValue(sw_name, pkv.parse(str(sw_version)))
)

for backend in BACKENDS:
if backend == ALPAKA_ACC_GPU_CUDA_ENABLE:
Expand Down Expand Up @@ -180,7 +175,7 @@ def is_supported_version(name: ValueName, version: ValueVersion) -> bool:
if name not in known_names:
raise ValueError(f"Unknown software name: {name}")

local_versions = VERSIONS.copy()
local_versions = copy.deepcopy(VERSIONS)

local_versions[CLANG_CUDA] = local_versions[CLANG]
local_versions[ALPAKA_ACC_GPU_CUDA_ENABLE] = [OFF]
Expand Down
3 changes: 2 additions & 1 deletion tests/test_expected_parameter_value_pairs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=missing-docstring
import unittest
import copy
from typing import List, Dict
from collections import OrderedDict
import io
Expand Down Expand Up @@ -635,7 +636,7 @@ def test_remove_parameter_value_pair_all_versions(self):

expected_number_of_reduced_pairs = len(reduced_param_value_pairs)

expected_reduced_param_value_pairs = reduced_param_value_pairs.copy()
expected_reduced_param_value_pairs = copy.deepcopy(reduced_param_value_pairs)

# remove single value to verify that default flag is working
example_single_pair = create_parameter_value_pair(
Expand Down
69 changes: 68 additions & 1 deletion tests/test_generate_combination_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
import os
import io
import copy
from collections import OrderedDict
import packaging.version as pkv
from utils_test import parse_param_vals
Expand Down Expand Up @@ -145,7 +146,7 @@ def custom_filter(row: ParameterValueTuple) -> bool:
parameter_value_matrix=self.param_matrix, custom_filter=custom_filter
)

reduced_expected_param_val_pairs = self.generated_parameter_value_pairs.copy()
reduced_expected_param_val_pairs = copy.deepcopy(self.generated_parameter_value_pairs)
for device_compiler in self.param_matrix[DEVICE_COMPILER]:
if device_compiler.name == NVCC:
self.assertTrue(
Expand Down Expand Up @@ -241,3 +242,69 @@ def custom_filter(row: ParameterValueTuple) -> bool:
number_of_combs = len(missing_combinations_str.split("\n"))
print(f"\nnumber of missing combinations: {number_of_combs}")
raise e


class TestParameterMatrixFilter(unittest.TestCase):
    # Each test extends a small base matrix with parameter-values which need to be filtered
    # (nvcc as host compiler, old clang-cuda versions), runs generate_combination_list() and
    # compares the result with get_expected_bashi_parameter_value_pairs().

    @classmethod
    def setUpClass(cls):
        base_matrix: ParameterValueMatrix = OrderedDict()

        # host and device compiler start with the same values, but each parameter owns its
        # list
        base_matrix[HOST_COMPILER] = parse_param_vals(
            [(GCC, 10), (GCC, 11), (GCC, 12), (CLANG, 16), (CLANG, 17)]
        )
        base_matrix[DEVICE_COMPILER] = parse_param_vals(
            [(GCC, 10), (GCC, 11), (GCC, 12), (CLANG, 16), (CLANG, 17)]
        )
        base_matrix[CMAKE] = parse_param_vals([(CMAKE, 3.22), (CMAKE, 3.23)])
        base_matrix[BOOST] = parse_param_vals([(BOOST, 1.81), (BOOST, 1.82), (BOOST, 1.83)])

        cls.param_base_matrix = base_matrix

    def test_nvcc_host_compiler_rule_n1(self):
        # test if generate_combination_list() correctly handles nvcc as host compiler
        param_matrix = copy.deepcopy(self.param_base_matrix)
        for compiler_type in (HOST_COMPILER, DEVICE_COMPILER):
            # every compiler list gets its own ParameterValue objects
            param_matrix[compiler_type].extend(
                ParameterValue(NVCC, pkv.parse(str(nvcc_version)))
                for nvcc_version in (11.2, 11.3, 11.8, 12.0)
            )
        param_matrix_before = copy.deepcopy(param_matrix)

        comb_list = generate_combination_list(param_matrix)

        # generate_combination_list should not modify the param_matrix
        self.assertEqual(param_matrix_before, param_matrix)

        expected_pairs = get_expected_bashi_parameter_value_pairs(param_matrix)
        self.assertTrue(check_parameter_value_pair_in_combination_list(comb_list, expected_pairs))

    def test_clang_cuda_old_versions_rule_v5(self):
        # test if generate_combination_list() correctly filters clang-cuda version 13 and
        # older
        param_matrix = copy.deepcopy(self.param_base_matrix)
        for compiler_type in (HOST_COMPILER, DEVICE_COMPILER):
            # versions below and above the supported minimum version 14
            param_matrix[compiler_type].extend(
                ParameterValue(CLANG_CUDA, pkv.parse(str(clang_cuda_version)))
                for clang_cuda_version in (8, 13, 14, 17)
            )
        param_matrix_before = copy.deepcopy(param_matrix)

        comb_list = generate_combination_list(parameter_value_matrix=param_matrix)

        # generate_combination_list should not modify the param_matrix
        self.assertEqual(param_matrix_before, param_matrix)

        expected_pairs = get_expected_bashi_parameter_value_pairs(param_matrix)
        self.assertTrue(check_parameter_value_pair_in_combination_list(comb_list, expected_pairs))
16 changes: 5 additions & 11 deletions tests/test_params_value_matrix_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=missing-docstring
import unittest
import packaging.version as pkv
import copy
from bashi.versions import VERSIONS, get_parameter_value_matrix
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import

Expand All @@ -25,23 +25,17 @@ def test_all_params_in(self):
)

def test_number_host_device_compiler(self):
extended_versions = VERSIONS.copy()
extended_versions = copy.deepcopy(VERSIONS)
# filter clang-cuda 13 and older because the pair-wise generator cannot filter it out
# afterwards
extended_versions[CLANG_CUDA] = list(
filter(
lambda clang_version: pkv.parse(str(clang_version)) >= pkv.parse("14"),
extended_versions[CLANG],
)
)
extended_versions[CLANG_CUDA] = extended_versions[CLANG]

number_of_host_compilers = 0
for compiler in COMPILERS:
if compiler != NVCC:
number_of_host_compilers += len(extended_versions[compiler])
number_of_host_compilers += len(extended_versions[compiler])

# NVCC is only as device compiler added
number_of_device_compilers = number_of_host_compilers + len(extended_versions[NVCC])
number_of_device_compilers = number_of_host_compilers

self.assertEqual(len(self.param_val_matrix[HOST_COMPILER]), number_of_host_compilers)
self.assertEqual(len(self.param_val_matrix[DEVICE_COMPILER]), number_of_device_compilers)
Expand Down

0 comments on commit 5d025ad

Please sign in to comment.