Skip to content

Commit

Permalink
Merge pull request #5 from SimeonEhrig/FilterAdapter
Browse files Browse the repository at this point in the history
implement the FilterAdapter
  • Loading branch information
SimeonEhrig authored Jan 24, 2024
2 parents 6d0f9fe + dda7a6d commit af21929
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testDeploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
needs: formatter
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12']
name: Run unit tests with Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v4
Expand Down
73 changes: 73 additions & 0 deletions bashi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Different helper functions for bashi"""

from typing import Dict, Callable, Tuple, List
from collections import OrderedDict
from packaging.version import Version
from typeguard import typechecked


class FilterAdapter:
"""
An adapter for the filter functions used by allpairspy to provide a better filter function
interface.
Independent of the type of `parameter` (in the bashi naming convention:
parameter-value-matrix type) used as an argument of AllPairs.__init__(), allpairspy always
passes the same row type to the filter function: List of parameter-values.
Therefore, the parameter name is encoded in the position in the row list. This makes it
much more difficult to write filter rules.
The FilterAdapter transforms the list of parameter values into a parameter-value-tuple, which
has the type OrderedDict[str, Tuple[str, Version]].
This user writes a filter rule function with the expected line type
OrderedDict[str, Tuple[str, Version]], creates a FunctionAdapter object with the functor as a
parameter and passes the FunctionAdapter object to AllPairs.__init__().
filter function example:
def filter_function(row: OrderedDict[str, Tuple[str, Version]]):
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER][NAME] == NVCC
and row[DEVICE_COMPILER][VERSION] < pkv.parse("12.0")
):
return False
return True
"""

@typechecked
def __init__(
self,
param_map: Dict[int, str],
filter_func: Callable[[OrderedDict[str, Tuple[str, Version]]], bool],
):
"""Create a new FilterAdapter, see class doc string.
Args:
param_map (Dict[int, str]): The param_map maps the index position of a parameter to the
parameter name. Assuming the parameter-value-matrix has the following keys:
["param1", "param2", "param3"], the param_map should look like this:
{0: "param1", 1 : "param2", 2 : "param3"}.
filter_func (Callable[[OrderedDict[str, Tuple[str, Version]]], bool]): The filter
function used by allpairspy, see class doc string.
"""
self.param_map = param_map
self.filter_func = filter_func

def __call__(self, row: List[Tuple[str, Version]]) -> bool:
"""The expected interface of allpairspy filter rule.
Transform the type of row from List[Tuple[str, Version]] to
[OrderedDict[str, Tuple[str, Version]]].
Args:
row (List[Tuple[str, Version]]): the parameter-value-tuple
Returns:
bool: Returns True, if the parameter-value-tuple is valid
"""
ordered_row: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
for index, param_name in enumerate(row):
ordered_row[self.param_map[index]] = param_name
return self.filter_func(ordered_row)
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
{name = "Jan Stephan", email = "[email protected]"},
]
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
description = "The library provides everything needed to generate a sparse combination matrix for alpaka-based projects, including a set of general-purpose combination rules."
dynamic = ["version"]
Expand All @@ -24,7 +24,8 @@ classifiers= [
]
dependencies = [
"allpairspy == 2.5.1",
"typeguard"
"typeguard",
"packaging"
]

[project.scripts]
Expand All @@ -44,3 +45,6 @@ Issues = "https://github.com/alpaka-group/bashi/issues"
command_line = "-m unittest discover -s tests/"
branch = true
source = ["bashi"]

[tool.black]
line-length = 100
162 changes: 162 additions & 0 deletions tests/test_filter_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# pylint: disable=missing-docstring
import unittest
from typing import Tuple, Dict, List
from collections import OrderedDict
from packaging.version import Version
import packaging.version as pkv
from typeguard import typechecked
from bashi.utils import FilterAdapter


class TestFilterAdapterDataSet1(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
cls.param_val_tuple["param1"] = ("param-val-name1", pkv.parse("1"))
cls.param_val_tuple["param2"] = ("param-val-name2", pkv.parse("2"))
cls.param_val_tuple["param3"] = ("param-val-name3", pkv.parse("3"))

cls.param_map: Dict[int, str] = {}
for index, param_name in enumerate(cls.param_val_tuple.keys()):
cls.param_map[index] = param_name

cls.test_row: List[Tuple[str, Version]] = []
for param_val in cls.param_val_tuple.values():
cls.test_row.append(param_val)

# use typechecked to do a deep type check
# isinstance() only verify the "outer" data type, which is OrderedDict
# isinstance() does not verify the key and value type
def test_function_type(self):
@typechecked
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
if len(row.keys()) < 1:
raise AssertionError("There is no element in row.")

# typechecked does not check the types of Tuple, therefore I "unwrap" it
@typechecked
def check_param_value_type(_: Tuple[str, Version]):
pass

check_param_value_type(next(iter(row.values())))

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_function_length(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
if len(row) != 3:
raise AssertionError(f"Size of test_row is {len(row)}. Expected is 3.")

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_function_row_order(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
excepted_param_order = ["param1", "param2", "param3"]
if len(excepted_param_order) != len(row):
raise AssertionError(
"excepted_key_order and row has not the same length.\n"
f"{len(excepted_param_order)} != {len(row)}"
)

for index, param in enumerate(row.keys()):
if excepted_param_order[index] != param:
raise AssertionError(
f"The {index}. parameter is not the expected "
f"parameter: {excepted_param_order[index]}"
)

expected_param_value_order = [
("param-val-name1", pkv.parse("1")),
("param-val-name2", pkv.parse("2")),
("param-val-name3", pkv.parse("3")),
]

for index, param_value in enumerate(row.values()):
expected_value_name = expected_param_value_order[index][0]
expected_value_version = expected_param_value_order[index][1]
if (
expected_value_name != param_value[0]
or expected_value_version != param_value[1]
):
raise AssertionError(
f"The {index}. parameter-value is not the expected parameter-value\n"
f"Get: {param_value}\n"
f"Expected: {expected_param_value_order[index]}"
)

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_lambda(self):
filter_adapter = FilterAdapter(self.param_map, lambda row: len(row) == 3)
self.assertTrue(filter_adapter(self.test_row), "row has not the length of 3")


# do a complex test with a different data set
class TestFilterAdapterDataSet2(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
cls.param_val_tuple["param6b"] = ("param-val-name1", pkv.parse("3.21.2"))
cls.param_val_tuple["param231a"] = ("param-val-name67asd", pkv.parse("2.4"))
cls.param_val_tuple["param234s"] = ("param-val-678", pkv.parse("3"))
cls.param_val_tuple["foo"] = ("foo", pkv.parse("12.3"))
cls.param_val_tuple["bar"] = ("bar", pkv.parse("3"))

cls.param_map: Dict[int, str] = {}
for index, param_name in enumerate(cls.param_val_tuple.keys()):
cls.param_map[index] = param_name

cls.test_row: List[Tuple[str, Version]] = []
for param_val in cls.param_val_tuple.values():
cls.test_row.append(param_val)

def test_function_row_lenght_order(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
excepted_param_order = ["param6b", "param231a", "param234s", "foo", "bar"]
if len(excepted_param_order) != len(row):
raise AssertionError(
"excepted_key_order and row has not the same length.\n"
f"{len(excepted_param_order)} != {len(row)}"
)

for index, param in enumerate(row.keys()):
if excepted_param_order[index] != param:
raise AssertionError(
f"The {index}. parameter is not the expected "
f"parameter: {excepted_param_order[index]}"
)

expected_param_value_order = [
("param-val-name1", pkv.parse("3.21.2")),
("param-val-name67asd", pkv.parse("2.4")),
("param-val-678", pkv.parse("3")),
("foo", pkv.parse("12.3")),
("bar", pkv.parse("3")),
]

for index, param_value in enumerate(row.values()):
expected_value_name = expected_param_value_order[index][0]
expected_value_version = expected_param_value_order[index][1]
if (
expected_value_name != param_value[0]
or expected_value_version != param_value[1]
):
raise AssertionError(
f"The {index}. parameter-value is not the expected parameter-value\n"
f"Get: {param_value}\n"
f"Expected: {expected_param_value_order[index]}"
)

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

0 comments on commit af21929

Please sign in to comment.