From 8d9764c3840e606094d46286662c5ed652161509 Mon Sep 17 00:00:00 2001 From: Simeon Ehrig Date: Wed, 7 Feb 2024 13:00:43 +0100 Subject: [PATCH] change type of ParameterValuePair to named tuple - introduce new type ParameterValueSingle - add option to remove_parameter_value_pair() remove all pairs only by given name independent of the version number --- bashi/types.py | 8 +- bashi/utils.py | 109 +++++- docs/naming.md | 5 +- tests/test_expected_parameter_value_pairs.py | 354 +++++++++++++++++-- tests/test_generate_combination_list.py | 31 +- tests/utils_test.py | 24 +- 6 files changed, 456 insertions(+), 75 deletions(-) diff --git a/bashi/types.py b/bashi/types.py index 71ef54b..8f349ce 100644 --- a/bashi/types.py +++ b/bashi/types.py @@ -10,7 +10,13 @@ ParameterValue = NamedTuple("ParameterValue", [("name", ValueName), ("version", ValueVersion)]) ParameterValueList: TypeAlias = List[ParameterValue] ParameterValueMatrix: TypeAlias = OrderedDict[Parameter, ParameterValueList] -ParameterValuePair: TypeAlias = OrderedDict[Parameter, ParameterValue] +ParameterValueSingle = NamedTuple( + "ParameterValueSingle", [("parameter", Parameter), ("parameterValue", ParameterValue)] +) +ParameterValuePair = NamedTuple( + "ParameterValuePair", + [("first", ParameterValueSingle), ("second", ParameterValueSingle)], +) ParameterValueTuple: TypeAlias = OrderedDict[Parameter, ParameterValue] Combination: TypeAlias = OrderedDict[Parameter, ParameterValue] CombinationList: TypeAlias = List[Combination] diff --git a/bashi/utils.py b/bashi/utils.py index e5b2ebc..de8033d 100644 --- a/bashi/utils.py +++ b/bashi/utils.py @@ -1,14 +1,16 @@ """Different helper functions for bashi""" -from typing import Dict, List, IO +from typing import Dict, List, IO, Union from collections import OrderedDict import dataclasses import sys from typeguard import typechecked +import packaging.version from bashi.types import ( Parameter, ParameterValue, ParameterValueTuple, + ParameterValueSingle, ParameterValuePair, ParameterValueMatrix, CombinationList, @@ -104,6 +106,50 @@ def get_default_filter_chain( ) +@typechecked +def create_parameter_value_pair( # pylint: disable=too-many-arguments + parameter1: str, + value_name1: str, + value_version1: Union[int, float, str, packaging.version.Version], + parameter2: str, + value_name2: str, + value_version2: Union[int, float, str, packaging.version.Version], +) -> ParameterValuePair: + """Create parameter-value-pair from the given arguments. Parse parameter-versions if required. + + Args: + parameter1 (str): name of the first parameter + value_name1 (str): name of the first value-name + value_version1 (Union[int, float, str, packaging.version.Version]): version of first + value-version + parameter2 (str): name of the second parameter + value_name2 (str): name of the second value-name + value_version2 (Union[int, float, str, packaging.version.Version]): version of the second + value-version + + Returns: + ParameterValuePair: parameter-value-pair + """ + if isinstance(value_version1, packaging.version.Version): + parsed_value_version1: packaging.version.Version = value_version1 + else: + parsed_value_version1: packaging.version.Version = packaging.version.parse( + str(value_version1) + ) + + if isinstance(value_version2, packaging.version.Version): + parsed_value_version2: packaging.version.Version = value_version2 + else: + parsed_value_version2: packaging.version.Version = packaging.version.parse( + str(value_version2) + ) + + return ParameterValuePair( + ParameterValueSingle(parameter1, ParameterValue(value_name1, parsed_value_version1)), + ParameterValueSingle(parameter2, ParameterValue(value_name2, parsed_value_version2)), + ) + + @typechecked def get_expected_parameter_value_pairs( parameter_matrix: ParameterValueMatrix, @@ -156,35 +202,44 @@ def _loop_over_parameter_values( """ for v1_name, v1_version in parameters[v1_parameter]: for v2_name, v2_version in parameters[v2_parameter]: - param_val_pair: ParameterValuePair = OrderedDict() - param_val_pair[v1_parameter] = ParameterValue(v1_name, v1_version) - param_val_pair[v2_parameter] = ParameterValue(v2_name, v2_version) - expected_pairs.append(param_val_pair) + expected_pairs.append( + create_parameter_value_pair( + v1_parameter, v1_name, v1_version, v2_parameter, v2_name, v2_version + ) + ) @typechecked def remove_parameter_value_pair( - to_remove: ParameterValuePair, parameter_value_pairs: List[ParameterValuePair] + to_remove: Union[ParameterValueSingle, ParameterValuePair], + parameter_value_pairs: List[ParameterValuePair], + all_versions: bool = False, ) -> bool: """Removes a parameter-value pair with one or two entries from the parameter-value-pair list. If the parameter-value-pair only has one parameter value, all parameter-value-pairs that contain the parameter value are removed. Args: - to_remove (ParameterValuePair): Parameter-value-pair to remove + to_remove (Union[ParameterValueSingle, ParameterValuePair]): Parameter-value-single or + parameter-value-pair to remove param_val_pairs (List[ParameterValuePair]): List of parameter-value-pairs. Will be modified. - + all_versions (bool): If it is `True` and `to_remove` has type of `ParameterValuePair`, + removes all parameter-value-pairs witch matches the value-names independent of the + value-version. Defaults to False. Raises: - RuntimeError: If `to_remove` 0 or more than 2 entries. + RuntimeError: If `all_versions=True` and `to_remove` is a `ParameterValueSingle` Returns: bool: True if entry was removed from parameter_value_pairs """ - if len(to_remove) == 1: + if isinstance(to_remove, ParameterValueSingle): + if all_versions: + raise RuntimeError("all_versions=True is not support for ParameterValueSingle") + return _remove_single_entry_parameter_value_pair(to_remove, parameter_value_pairs) - if len(to_remove) == 0 or len(to_remove) > 2: - raise RuntimeError("More than two parameter-values are not allowed") + if all_versions: + return _remove_parameter_value_pair_all_versions(to_remove, parameter_value_pairs) try: parameter_value_pairs.remove(to_remove) @@ -195,14 +250,34 @@ def remove_parameter_value_pair( @typechecked def _remove_single_entry_parameter_value_pair( - to_remove: ParameterValuePair, param_val_pairs: List[ParameterValuePair] + to_remove: ParameterValueSingle, param_val_pairs: List[ParameterValuePair] ) -> bool: - val_name, val_version = next(iter(to_remove.items())) + len_before = len(param_val_pairs) + + def filter_function(param_val_pair: ParameterValuePair) -> bool: + for param_val_entry in param_val_pair: + if param_val_entry == to_remove: + return False + return True + + param_val_pairs[:] = list(filter(filter_function, param_val_pairs)) + + return len_before != len(param_val_pairs) + +@typechecked +def _remove_parameter_value_pair_all_versions( + to_remove: ParameterValuePair, param_val_pairs: List[ParameterValuePair] +) -> bool: len_before = len(param_val_pairs) def filter_function(param_val_pair: ParameterValuePair) -> bool: - if val_name in param_val_pair and param_val_pair[val_name] == val_version: + if ( + param_val_pair.first.parameter == to_remove.first.parameter + and param_val_pair.second.parameter == to_remove.second.parameter + and param_val_pair.first.parameterValue.name == to_remove.first.parameterValue.name + and param_val_pair.second.parameterValue.name == to_remove.second.parameterValue.name + ): return False return True @@ -232,8 +307,8 @@ def check_parameter_value_pair_in_combination_list( missing_expected_param = False for ex_param_val_pair in parameter_value_pairs: - param1, param_val1 = list(ex_param_val_pair.items())[0] - param2, param_val2 = list(ex_param_val_pair.items())[1] + param1, param_val1 = ex_param_val_pair[0] + param2, param_val2 = ex_param_val_pair[1] found = False for comb in combination_list: # comb contains all parameters, therefore a check is not required diff --git a/docs/naming.md b/docs/naming.md index 64f2ac4..999127e 100644 --- a/docs/naming.md +++ b/docs/naming.md @@ -5,7 +5,7 @@ The [pair-wise testing](https://en.wikipedia.org/wiki/All-pairs_testing) takes a The real Python types are implemented in [types.py](../bashi/types.py) - **parameter** (`str`): A `parameter` represents a software component like the host compiler or a specific software like `CMake` or `Boost`. A `parameter` names a list of `parameter-values` and expresses how a `parameter-value` is used. -- **parameter-value** (`Tuple[value-name: str, value-version: packaging.version.Version]`): A `parameter-value` represents of a specific version of a `parameter`, for example `GCC 10`, `nvcc 12.2` or `CMake 3.28`. The pair wise generator takes on `parameter-value` of each `parameter` for is combinatorics. +- **parameter-value** (`NamedTuple[value-name: str, value-version: packaging.version.Version]`): A `parameter-value` represents of a specific version of a `parameter`, for example `GCC 10`, `nvcc 12.2` or `CMake 3.28`. The pair wise generator takes on `parameter-value` of each `parameter` for is combinatorics. - **value-name** (`str`): The `value-name` is the first part of the `parameter-value`. It names a specific software component, like `GCC` or `CMake`. If the `parameter` names a specific software, `parameter` and `value-name` are equal. - **value-version** (`packaging.version.Version`): The `value-version` is the second part of the `parameter-value`. It defines a specific version of a software component such like `12.2` or `3.12`. - **parameter-value-list** (`parameter: str = List[parameter-value: Tuple[value-name: str, value-version: packaging.version.Version]]`): A `parameter-value-list` is a list of `parameter-values` assigned to a `parameter`. For example: @@ -16,4 +16,5 @@ The real Python types are implemented in [types.py](../bashi/types.py) - **parameter-value-tuple** (`OrderedList[parameter: str, parameter-value: Tuple[value-name: str, value-version: packaging.version.Version]]`): A `parameter-value-tuple` is a list of one ore more `parameter-value`s. The `parameter-value-tuple` is created from a `parameter-value-matrix` and each `parameter-value` is assigned to a different `parameter`. This means, each `parameter-value` is from a different `parameter-value-list` in a `parameter-value-matrix`. The `parameter-value-tuple` has the same or a smaller number of entries as the number of `parameters` in a `parameter-value-matrix`. - **combination** (`OrderedList[parameter: str, parameter-value: Tuple[value-name: str, value-version: packaging.version.Version]]`): A `combination` is a `parameter-value-tuple` with the same number of `parameter-value`s as the number of input `parameters`. - **combination-list** (`List[OrderedList[parameter : str, parameter-value: Tuple[value-name: str, value-version: packaging.version.Version]]]`): A `combination-list` is a list of `combination`s an the result of the pair-wise generator. -- **parameter-value-pair** (`OrderedList[parameter: str, parameter-value: Tuple[value-name: str, value-version: packaging.version.Version]]`): A `parameter-value-pair` is a `parameter-value-tuple` with exact two `parameter-values`. The pair-wise generator guaranties that each `parameter-value-pair`, which can be created by the given `parameter-value-matrix` exists at least in one `combination` of the `combination-list`. The only exception is, if a `parameter-value-pair` is forbidden by a filter rule. +- **parameter-value-single** (`NamedTuple[parameter: str, parameter-value: NamedTuple[value-name: str, value-version: packaging.version.Version]]`): A `parameter-value-single` connects a `parameter` with a single `parameter-value`. +- **parameter-value-pair** (`NamedTuple[first: NamedTuple[parameter: str, parameter-value: NamedTuple[value-name: str, value-version: packaging.version.Version]], second: NamedTuple[parameter: str, parameter-value: NamedTuple[value-name: str, value-version: packaging.version.Version]]]`): A `parameter-value-pair` is a `parameter-value-tuple` with exact two `parameter-values`. The pair-wise generator guaranties that each `parameter-value-pair`, which can be created by the given `parameter-value-matrix` exists at least in one `combination` of the `combination-list`. The only exception is, if a `parameter-value-pair` is forbidden by a filter rule. diff --git a/tests/test_expected_parameter_value_pairs.py b/tests/test_expected_parameter_value_pairs.py index 1e2654f..3b1b6d3 100644 --- a/tests/test_expected_parameter_value_pairs.py +++ b/tests/test_expected_parameter_value_pairs.py @@ -1,16 +1,19 @@ # pylint: disable=missing-docstring import unittest -from typing import List +from typing import List, Dict from collections import OrderedDict import io +import packaging.version as pkv # allpairspy has no type hints from allpairspy import AllPairs # type: ignore from utils_test import parse_param_val, parse_param_vals, parse_expected_val_pairs from bashi.types import ( + Parameter, + ParameterValue, + ParameterValueSingle, ParameterValuePair, ParameterValueMatrix, - Combination, CombinationList, ) from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import @@ -18,9 +21,162 @@ get_expected_parameter_value_pairs, check_parameter_value_pair_in_combination_list, remove_parameter_value_pair, + create_parameter_value_pair, ) +class TestCreateParameterValuePair(unittest.TestCase): + def test_create_parameter_value_pair_type_str(self): + param1 = "param1" + param2 = "param2" + name1 = "name1" + name2 = "name2" + ver1 = "1.0" + ver2 = "2.0" + + created_param_val_pair = create_parameter_value_pair( + param1, name1, ver1, param2, name2, ver2 + ) + + expected_param_val_pair = ParameterValuePair( + ParameterValueSingle( + param1, + ParameterValue(name1, pkv.parse(ver1)), + ), + ParameterValueSingle( + param2, + ParameterValue(name2, pkv.parse(ver2)), + ), + ) + + self.assertEqual(type(created_param_val_pair), type(expected_param_val_pair)) + + self.assertEqual( + created_param_val_pair, + expected_param_val_pair, + ) + + def test_create_parameter_value_pair_type_float(self): + param1 = "param1" + param2 = "param2" + name1 = "name1" + name2 = "name2" + ver1 = 1.0 + ver2 = 2.0 + + created_param_val_pair = create_parameter_value_pair( + param1, name1, ver1, param2, name2, ver2 + ) + + expected_param_val_pair = ParameterValuePair( + ParameterValueSingle( + param1, + ParameterValue(name1, pkv.parse(str(ver1))), + ), + ParameterValueSingle( + param2, + ParameterValue(name2, pkv.parse(str(ver2))), + ), + ) + + self.assertEqual(type(created_param_val_pair), type(expected_param_val_pair)) + + self.assertEqual( + created_param_val_pair, + expected_param_val_pair, + ) + + def test_create_parameter_value_pair_type_int(self): + param1 = "param1" + param2 = "param2" + name1 = "name1" + name2 = "name2" + ver1 = 1 + ver2 = 2 + + created_param_val_pair = create_parameter_value_pair( + param1, name1, ver1, param2, name2, ver2 + ) + + expected_param_val_pair = ParameterValuePair( + ParameterValueSingle( + param1, + ParameterValue(name1, pkv.parse(str(ver1))), + ), + ParameterValueSingle( + param2, + ParameterValue(name2, pkv.parse(str(ver2))), + ), + ) + + self.assertEqual(type(created_param_val_pair), type(expected_param_val_pair)) + + self.assertEqual( + created_param_val_pair, + expected_param_val_pair, + ) + + def test_create_parameter_value_pair_type_version(self): + param1 = "param1" + param2 = "param2" + name1 = "name1" + name2 = "name2" + ver1 = pkv.parse("1.0") + ver2 = pkv.parse("2.0") + + created_param_val_pair = create_parameter_value_pair( + param1, name1, ver1, param2, name2, ver2 + ) + + expected_param_val_pair = ParameterValuePair( + ParameterValueSingle( + param1, + ParameterValue(name1, ver1), + ), + ParameterValueSingle( + param2, + ParameterValue(name2, ver2), + ), + ) + + self.assertEqual(type(created_param_val_pair), type(expected_param_val_pair)) + + self.assertEqual( + created_param_val_pair, + expected_param_val_pair, + ) + + def test_create_parameter_value_pair_type_mixed(self): + param1 = "param1" + param2 = "param2" + name1 = "name1" + name2 = "name2" + ver1 = pkv.parse("1.0") + ver2 = "2.0" + + created_param_val_pair = create_parameter_value_pair( + param1, name1, ver1, param2, name2, ver2 + ) + + expected_param_val_pair = ParameterValuePair( + ParameterValueSingle( + param1, + ParameterValue(name1, ver1), + ), + ParameterValueSingle( + param2, + ParameterValue(name2, pkv.parse(ver2)), + ), + ) + + self.assertEqual(type(created_param_val_pair), type(expected_param_val_pair)) + + self.assertEqual( + created_param_val_pair, + expected_param_val_pair, + ) + + class TestExpectedValuePairs(unittest.TestCase): @classmethod def setUpClass(cls): @@ -35,9 +191,9 @@ def setUpClass(cls): cls.param_matrix[CMAKE] = parse_param_vals([(CMAKE, 3.22), (CMAKE, 3.23)]) cls.param_matrix[BOOST] = parse_param_vals([(BOOST, 1.81), (BOOST, 1.82), (BOOST, 1.83)]) - cls.generated_parameter_value_pairs: List[ - ParameterValuePair - ] = get_expected_parameter_value_pairs(cls.param_matrix) + cls.generated_parameter_value_pairs: List[ParameterValuePair] = ( + get_expected_parameter_value_pairs(cls.param_matrix) + ) OD = OrderedDict @@ -348,9 +504,8 @@ def test_unrestricted_allpairspy_generator(self): class TestRemoveExpectedParameterValuePair(unittest.TestCase): - def test_remove_two_entry_parameter_value_pair(self): + def test_remove_parameter_value_pair(self): OD = OrderedDict - ppv = parse_param_val expected_param_value_pairs: List[ParameterValuePair] = parse_expected_val_pairs( [ @@ -362,29 +517,9 @@ def test_remove_two_entry_parameter_value_pair(self): ) original_length = len(expected_param_value_pairs) - # expects one or two entries - self.assertRaises( - RuntimeError, - remove_parameter_value_pair, - OD(), - expected_param_value_pairs, - ) - self.assertRaises( - RuntimeError, - remove_parameter_value_pair, - OD( - { - HOST_COMPILER: ppv((GCC, 9)), - DEVICE_COMPILER: ppv((NVCC, 11.2)), - CMAKE: ppv((CMAKE, 3.23)), - } - ), - expected_param_value_pairs, - ) - self.assertFalse( remove_parameter_value_pair( - OD({HOST_COMPILER: ppv((GCC, 9)), DEVICE_COMPILER: ppv((NVCC, 11.2))}), + create_parameter_value_pair(HOST_COMPILER, GCC, 9, DEVICE_COMPILER, NVCC, 11.2), expected_param_value_pairs, ) ) @@ -392,7 +527,7 @@ def test_remove_two_entry_parameter_value_pair(self): self.assertTrue( remove_parameter_value_pair( - OD({HOST_COMPILER: ppv((GCC, 10)), DEVICE_COMPILER: ppv((NVCC, 12.0))}), + create_parameter_value_pair(HOST_COMPILER, GCC, 10, DEVICE_COMPILER, NVCC, 12.0), expected_param_value_pairs, ) ) @@ -400,13 +535,13 @@ def test_remove_two_entry_parameter_value_pair(self): self.assertTrue( remove_parameter_value_pair( - OD({CMAKE: ppv((CMAKE, 3.23)), BOOST: ppv((BOOST, 1.83))}), + create_parameter_value_pair(CMAKE, CMAKE, 3.23, BOOST, BOOST, 1.83), expected_param_value_pairs, ) ) self.assertEqual(len(expected_param_value_pairs), original_length - 2) - def test_remove_single_entry_parameter_value_pair(self): + def test_remove_parameter_value_single(self): OD = OrderedDict ppv = parse_param_val @@ -421,9 +556,18 @@ def test_remove_single_entry_parameter_value_pair(self): ) original_length = len(expected_param_value_pairs) + # all_versions=True is not support for ParameterValueSingle + self.assertRaises( + RuntimeError, + remove_parameter_value_pair, + ParameterValueSingle(HOST_COMPILER, ppv((GCC, 12))), + expected_param_value_pairs, + True, + ) + self.assertFalse( remove_parameter_value_pair( - OD({HOST_COMPILER: ppv((GCC, 12))}), + ParameterValueSingle(HOST_COMPILER, ppv((GCC, 12))), expected_param_value_pairs, ) ) @@ -431,7 +575,7 @@ def test_remove_single_entry_parameter_value_pair(self): self.assertFalse( remove_parameter_value_pair( - OD({HOST_COMPILER: ppv((CLANG, 12))}), + ParameterValueSingle(HOST_COMPILER, ppv((CLANG, 12))), expected_param_value_pairs, ) ) @@ -439,7 +583,7 @@ def test_remove_single_entry_parameter_value_pair(self): self.assertFalse( remove_parameter_value_pair( - OD({UBUNTU: ppv((UBUNTU, 20.04))}), + ParameterValueSingle(UBUNTU, ppv((UBUNTU, 20.04))), expected_param_value_pairs, ) ) @@ -447,7 +591,7 @@ def test_remove_single_entry_parameter_value_pair(self): self.assertTrue( remove_parameter_value_pair( - OD({HOST_COMPILER: ppv((GCC, 9))}), + ParameterValueSingle(HOST_COMPILER, ppv((GCC, 9))), expected_param_value_pairs, ) ) @@ -462,3 +606,143 @@ def test_remove_single_entry_parameter_value_pair(self): ] ), ) + + def test_remove_parameter_value_pair_all_versions(self): + versions = { + GCC: [9, 10, 11, 12, 13], + CLANG: [13, 14, 15, 16, 17], + NVCC: [11.0, 11.1, 11.2, 11.3, 11.4], + HIPCC: [5.0, 5.1, 5.2, 5.3], + CMAKE: [3.22, 3.23, 3.24], + BOOST: [1.80, 1.81, 1.82], + } + + param_val_matrix: ParameterValueMatrix = OrderedDict() + for compiler in [HOST_COMPILER, DEVICE_COMPILER]: + param_val_matrix[compiler] = [] + for compiler_name in [GCC, CLANG, NVCC, HIPCC]: + for compiler_version in versions[compiler_name]: + param_val_matrix[compiler].append( + ParameterValue(compiler_name, pkv.parse(str(compiler_version))) + ) + + for sw in [CMAKE, BOOST]: + param_val_matrix[sw] = [] + for version in versions[sw]: + param_val_matrix[sw].append(ParameterValue(sw, pkv.parse(str(version)))) + + reduced_param_value_pairs = get_expected_parameter_value_pairs(param_val_matrix) + + expected_number_of_reduced_pairs = len(reduced_param_value_pairs) + + expected_reduced_param_value_pairs = reduced_param_value_pairs.copy() + + # remove single value to verify that default flag is working + example_single_pair = create_parameter_value_pair( + HOST_COMPILER, + NVCC, + 11.0, + DEVICE_COMPILER, + NVCC, + 11.3, + ) + + expected_reduced_param_value_pairs.remove(example_single_pair) + + self.assertTrue( + remove_parameter_value_pair( + to_remove=example_single_pair, + parameter_value_pairs=reduced_param_value_pairs, + ) + ) + + # remove single entry + expected_number_of_reduced_pairs -= 1 + self.assertEqual(len(reduced_param_value_pairs), expected_number_of_reduced_pairs) + + reduced_param_value_pairs.sort() + expected_reduced_param_value_pairs.sort() + self.assertEqual(reduced_param_value_pairs, expected_reduced_param_value_pairs) + + # remove all expected tuples, where host and device compiler is nvcc + def filter_function1(param_val_pair: ParameterValuePair) -> bool: + if ( + param_val_pair.first.parameter == HOST_COMPILER + and param_val_pair.second.parameter == DEVICE_COMPILER + ): + if ( + param_val_pair.first.parameterValue.name == NVCC + and param_val_pair.second.parameterValue.name == NVCC + ): + return False + + return True + + expected_reduced_param_value_pairs[:] = list( + filter(filter_function1, expected_reduced_param_value_pairs) + ) + + self.assertTrue( + remove_parameter_value_pair( + to_remove=create_parameter_value_pair( + HOST_COMPILER, + NVCC, + 0, + DEVICE_COMPILER, + NVCC, + 0, + ), + parameter_value_pairs=reduced_param_value_pairs, + all_versions=True, + ) + ) + + # remove number of pairs, where host and device compiler is nvcc + # -1 because we removed already a combination manually before + expected_number_of_reduced_pairs -= len(versions[NVCC]) * len(versions[NVCC]) - 1 + self.assertEqual(len(reduced_param_value_pairs), expected_number_of_reduced_pairs) + + reduced_param_value_pairs.sort() + expected_reduced_param_value_pairs.sort() + self.assertEqual(reduced_param_value_pairs, expected_reduced_param_value_pairs) + + # remove all combinations where HIPCC is the host compiler and nvcc the device compiler + def filter_function2(param_val_pair: ParameterValuePair) -> bool: + if ( + param_val_pair.first.parameter == HOST_COMPILER + and param_val_pair.second.parameter == DEVICE_COMPILER + ): + if ( + param_val_pair.first.parameterValue.name == HIPCC + and param_val_pair.second.parameterValue.name == NVCC + ): + return False + + return True + + expected_reduced_param_value_pairs[:] = list( + filter(filter_function2, expected_reduced_param_value_pairs) + ) + + self.assertTrue( + remove_parameter_value_pair( + to_remove=create_parameter_value_pair( + HOST_COMPILER, + HIPCC, + 0, + DEVICE_COMPILER, + NVCC, + 0, + ), + parameter_value_pairs=reduced_param_value_pairs, + all_versions=True, + ) + ) + + # remove number pairs, where host compiler is HIPCC and device compiler is nvcc + expected_number_of_reduced_pairs -= len(versions[HIPCC]) * len(versions[NVCC]) + self.assertEqual(len(reduced_param_value_pairs), expected_number_of_reduced_pairs) + + reduced_param_value_pairs.sort() + expected_reduced_param_value_pairs.sort() + self.assertEqual(reduced_param_value_pairs, expected_reduced_param_value_pairs) diff --git a/tests/test_generate_combination_list.py b/tests/test_generate_combination_list.py index c130095..988f7b5 100644 --- a/tests/test_generate_combination_list.py +++ b/tests/test_generate_combination_list.py @@ -11,9 +11,11 @@ get_expected_parameter_value_pairs, check_parameter_value_pair_in_combination_list, remove_parameter_value_pair, + create_parameter_value_pair, ) from bashi.types import ( ParameterValue, + ParameterValueSingle, ParameterValuePair, ParameterValueTuple, ParameterValueMatrix, @@ -35,9 +37,9 @@ def setUpClass(cls): cls.param_matrix[CMAKE] = parse_param_vals([(CMAKE, 3.22), (CMAKE, 3.23)]) cls.param_matrix[BOOST] = parse_param_vals([(BOOST, 1.81), (BOOST, 1.82), (BOOST, 1.83)]) - cls.generated_parameter_value_pairs: List[ - ParameterValuePair - ] = get_expected_parameter_value_pairs(cls.param_matrix) + cls.generated_parameter_value_pairs: List[ParameterValuePair] = ( + get_expected_parameter_value_pairs(cls.param_matrix) + ) def test_generator_without_custom_filter(self): comb_list = generate_combination_list(self.param_matrix) @@ -143,8 +145,8 @@ def custom_filter(row: ParameterValueTuple) -> bool: if device_compiler.name == NVCC: self.assertTrue( remove_parameter_value_pair( - OrderedDict( - {DEVICE_COMPILER: ParameterValue(NVCC, device_compiler.version)} + ParameterValueSingle( + DEVICE_COMPILER, ParameterValue(NVCC, device_compiler.version) ), reduced_expected_param_val_pairs, ) @@ -152,11 +154,13 @@ def custom_filter(row: ParameterValueTuple) -> bool: self.assertTrue( remove_parameter_value_pair( - OrderedDict( - { - CMAKE: ParameterValue(CMAKE, pkv.parse("3.23")), - BOOST: ParameterValue(BOOST, pkv.parse("1.82")), - }, + create_parameter_value_pair( + CMAKE, + CMAKE, + "3.23", + BOOST, + BOOST, + "1.82", ), reduced_expected_param_val_pairs, ) @@ -213,12 +217,7 @@ def custom_filter(row: ParameterValueTuple) -> bool: self.assertTrue( remove_parameter_value_pair( - OrderedDict( - { - CMAKE: ParameterValue(CMAKE, pkv.parse("3.23")), - BOOST: ParameterValue(BOOST, pkv.parse("1.82")), - }, - ), + create_parameter_value_pair(CMAKE, CMAKE, "3.23", BOOST, BOOST, "1.82"), reduced_expected_param_val_pairs, ) ) diff --git a/tests/utils_test.py b/tests/utils_test.py index 74182fb..f6753f4 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -9,6 +9,7 @@ ParameterValuePair, ValueName, ) +from bashi.utils import create_parameter_value_pair def parse_param_val(param_val: Tuple[ValueName, Union[str, int, float]]) -> ParameterValue: @@ -58,9 +59,24 @@ def parse_expected_val_pairs( expected_val_pairs: List[ParameterValuePair] = [] for param_val_pair in input_list: - tmp_entry: ParameterValuePair = OrderedDict() - for param in param_val_pair: - tmp_entry[param] = parse_param_val(param_val_pair[param]) - expected_val_pairs.append(tmp_entry) + if len(param_val_pair) != 2: + raise RuntimeError("input_list needs to have two entries") + + it = iter(param_val_pair.items()) + param1, param_val1 = next(it) + val_name1, val_version1 = param_val1 + param2, param_val2 = next(it) + val_name2, val_version2 = param_val2 + + expected_val_pairs.append( + create_parameter_value_pair( + param1, + val_name1, + val_version1, + param2, + val_name2, + val_version2, + ) + ) return expected_val_pairs