diff --git a/elastica/modules/feature_group.py b/elastica/modules/feature_group.py index 9578f954..55970a14 100644 --- a/elastica/modules/feature_group.py +++ b/elastica/modules/feature_group.py @@ -6,21 +6,50 @@ class FeatureGroupFIFO(Iterable): + """ + A class to store the features and their corresponding operators in a FIFO manner. + + Examples + -------- + >>> feature_group = FeatureGroupFIFO() + >>> feature_group.append_id(obj_1) + >>> feature_group.append_id(obj_2) + >>> feature_group.add_operators(obj_1, [OperatorType.ADD, OperatorType.SUBTRACT]) + >>> feature_group.add_operators(obj_2, [OperatorType.SUBTRACT, OperatorType.MULTIPLY]) + >>> list(feature_group) + [OperatorType.ADD, OperatorType.SUBTRACT, OperatorType.SUBTRACT, OperatorType.MULTIPLY] + + Attributes + ---------- + _operator_collection : list[list[OperatorType]] + A list of lists of operators. Each list of operators corresponds to a feature. + _operator_ids : list[int] + A list of ids of the features. + + Methods + ------- + append_id(feature) + Appends the id of the feature to the list of ids. + add_operators(feature, operators) + Adds the operators to the list of operators corresponding to the feature. + """ + def __init__(self): self._operator_collection: list[list[OperatorType]] = [] self._operator_ids: list[int] = [] def __iter__(self) -> OperatorType: - if not self._operator_collection: - raise RuntimeError("Feature group is not instantiated.") + """Returns an operator iterator to satisfy the Iterable protocol.""" operator_chain = itertools.chain.from_iterable(self._operator_collection) for operator in operator_chain: yield operator def append_id(self, feature): + """Appends the id of the feature to the list of ids.""" self._operator_ids.append(id(feature)) self._operator_collection.append([]) def add_operators(self, feature, operators: list[OperatorType]): + """Adds the operators to the list of operators corresponding to the feature.""" idx = self._operator_ids.index(id(feature)) self._operator_collection[idx].extend(operators) diff --git a/tests/test_modules/test_feature_grouping.py b/tests/test_modules/test_feature_grouping.py new file mode 100644 index 00000000..04c0d521 --- /dev/null +++ b/tests/test_modules/test_feature_grouping.py @@ -0,0 +1,58 @@ +import pytest + +from elastica.modules.feature_group import FeatureGroupFIFO + + +def test_add_ids(): + feature_group = FeatureGroupFIFO() + feature_group.append_id(1) + feature_group.append_id(2) + feature_group.append_id(3) + + assert feature_group._operator_ids == [id(1), id(2), id(3)] + + +def test_add_operators(): + feature_group = FeatureGroupFIFO() + feature_group.append_id(1) + feature_group.add_operators(1, [1, 2, 3]) + feature_group.append_id(2) + feature_group.add_operators(2, [4, 5, 6]) + feature_group.append_id(3) + feature_group.add_operators(3, [7, 8, 9]) + + assert feature_group._operator_collection == [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + assert feature_group._operator_ids == [id(1), id(2), id(3)] + + feature_group.append_id(4) + feature_group.add_operators(4, [10, 11, 12]) + + assert feature_group._operator_collection == [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + ] + assert feature_group._operator_ids == [id(1), id(2), id(3), id(4)] + + +def test_grouping(): + feature_group = FeatureGroupFIFO() + feature_group.append_id(1) + feature_group.add_operators(1, [1, 2, 3]) + feature_group.append_id(2) + feature_group.add_operators(2, [4, 5, 6]) + feature_group.append_id(3) + feature_group.add_operators(3, [7, 8, 9]) + + assert list(feature_group) == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + feature_group.append_id(4) + feature_group.add_operators(4, [10, 11, 12]) + + assert list(feature_group) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + + feature_group.append_id(1) + feature_group.add_operators(1, [13, 14, 15]) + + assert list(feature_group) == [1, 2, 3, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12]