Skip to content

Commit

Permalink
🔮 global scope parameter passing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurenzBeck committed Jan 8, 2025
1 parent c3b836e commit 7650577
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog/11.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
🔮 global scope parameter passing logic.
53 changes: 26 additions & 27 deletions examples/time series classification/01-static-distributions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
" This ensures that the structure of the tree is preserved (Otherwise we would create a more generic directed acyclic graph),\n",
" which is not supported by `anytree`.\n",
"\n",
"The parameters of the functions inside a tree are stored in a `ParameterStore` object. A sampling tree knows which parameters need to be passed to the functions based on **scopes**.\n",
"The parameters of the functions inside a tree are stored in a `ParameterStore` object. A sampling tree fetches arguments automatically from this store. To avoid naming conflicts and to have more control, you can also use the transform names as **scopes**. Scoped parameters have precedence over top-level parameters.\n",
"\n",
"> 👉 more details about parameter handling is explained in the next notebook on [🌌 data streams](./02-data-streams.ipynb)."
]
Expand Down Expand Up @@ -230,12 +230,11 @@
" },\n",
" \"ramp\": {\n",
" \"height\": 1.0,\n",
" \"length\": 128,\n",
" },\n",
" \"step\": {\n",
" \"length\": 128,\n",
" \"kernel_size\": 10,\n",
" },\n",
" \"length\": 128, # top-level length parameter will be passed to both the `step` and `ramp` functions\n",
"}\n",
"\n",
"# 🎲🌳 tree of transformations\n",
Expand Down Expand Up @@ -528,7 +527,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c4f8ea7b30ef454591290a61689fdb35",
"model_id": "890b54ed941c4472a3c4790047e0431d",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -939,42 +938,42 @@
"</style>\n",
"\n",
"<div class=\"animation\">\n",
" <img id=\"_anim_imgb27a23862fdf45fb928ec38c9af03191\">\n",
" <img id=\"_anim_img7204dcaa93b04402a6b8842c19d81b91\">\n",
" <div class=\"anim-controls\">\n",
" <input id=\"_anim_sliderb27a23862fdf45fb928ec38c9af03191\" type=\"range\" class=\"anim-slider\"\n",
" <input id=\"_anim_slider7204dcaa93b04402a6b8842c19d81b91\" type=\"range\" class=\"anim-slider\"\n",
" name=\"points\" min=\"0\" max=\"1\" step=\"1\" value=\"0\"\n",
" oninput=\"animb27a23862fdf45fb928ec38c9af03191.set_frame(parseInt(this.value));\">\n",
" oninput=\"anim7204dcaa93b04402a6b8842c19d81b91.set_frame(parseInt(this.value));\">\n",
" <div class=\"anim-buttons\">\n",
" <button title=\"Decrease speed\" aria-label=\"Decrease speed\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.slower()\">\n",
" <button title=\"Decrease speed\" aria-label=\"Decrease speed\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.slower()\">\n",
" <i class=\"fa fa-minus\"></i></button>\n",
" <button title=\"First frame\" aria-label=\"First frame\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.first_frame()\">\n",
" <button title=\"First frame\" aria-label=\"First frame\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.first_frame()\">\n",
" <i class=\"fa fa-fast-backward\"></i></button>\n",
" <button title=\"Previous frame\" aria-label=\"Previous frame\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.previous_frame()\">\n",
" <button title=\"Previous frame\" aria-label=\"Previous frame\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.previous_frame()\">\n",
" <i class=\"fa fa-step-backward\"></i></button>\n",
" <button title=\"Play backwards\" aria-label=\"Play backwards\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.reverse_animation()\">\n",
" <button title=\"Play backwards\" aria-label=\"Play backwards\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.reverse_animation()\">\n",
" <i class=\"fa fa-play fa-flip-horizontal\"></i></button>\n",
" <button title=\"Pause\" aria-label=\"Pause\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.pause_animation()\">\n",
" <button title=\"Pause\" aria-label=\"Pause\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.pause_animation()\">\n",
" <i class=\"fa fa-pause\"></i></button>\n",
" <button title=\"Play\" aria-label=\"Play\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.play_animation()\">\n",
" <button title=\"Play\" aria-label=\"Play\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.play_animation()\">\n",
" <i class=\"fa fa-play\"></i></button>\n",
" <button title=\"Next frame\" aria-label=\"Next frame\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.next_frame()\">\n",
" <button title=\"Next frame\" aria-label=\"Next frame\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.next_frame()\">\n",
" <i class=\"fa fa-step-forward\"></i></button>\n",
" <button title=\"Last frame\" aria-label=\"Last frame\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.last_frame()\">\n",
" <button title=\"Last frame\" aria-label=\"Last frame\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.last_frame()\">\n",
" <i class=\"fa fa-fast-forward\"></i></button>\n",
" <button title=\"Increase speed\" aria-label=\"Increase speed\" onclick=\"animb27a23862fdf45fb928ec38c9af03191.faster()\">\n",
" <button title=\"Increase speed\" aria-label=\"Increase speed\" onclick=\"anim7204dcaa93b04402a6b8842c19d81b91.faster()\">\n",
" <i class=\"fa fa-plus\"></i></button>\n",
" </div>\n",
" <form title=\"Repetition mode\" aria-label=\"Repetition mode\" action=\"#n\" name=\"_anim_loop_selectb27a23862fdf45fb928ec38c9af03191\"\n",
" <form title=\"Repetition mode\" aria-label=\"Repetition mode\" action=\"#n\" name=\"_anim_loop_select7204dcaa93b04402a6b8842c19d81b91\"\n",
" class=\"anim-state\">\n",
" <input type=\"radio\" name=\"state\" value=\"once\" id=\"_anim_radio1_b27a23862fdf45fb928ec38c9af03191\"\n",
" <input type=\"radio\" name=\"state\" value=\"once\" id=\"_anim_radio1_7204dcaa93b04402a6b8842c19d81b91\"\n",
" >\n",
" <label for=\"_anim_radio1_b27a23862fdf45fb928ec38c9af03191\">Once</label>\n",
" <input type=\"radio\" name=\"state\" value=\"loop\" id=\"_anim_radio2_b27a23862fdf45fb928ec38c9af03191\"\n",
" <label for=\"_anim_radio1_7204dcaa93b04402a6b8842c19d81b91\">Once</label>\n",
" <input type=\"radio\" name=\"state\" value=\"loop\" id=\"_anim_radio2_7204dcaa93b04402a6b8842c19d81b91\"\n",
" checked>\n",
" <label for=\"_anim_radio2_b27a23862fdf45fb928ec38c9af03191\">Loop</label>\n",
" <input type=\"radio\" name=\"state\" value=\"reflect\" id=\"_anim_radio3_b27a23862fdf45fb928ec38c9af03191\"\n",
" <label for=\"_anim_radio2_7204dcaa93b04402a6b8842c19d81b91\">Loop</label>\n",
" <input type=\"radio\" name=\"state\" value=\"reflect\" id=\"_anim_radio3_7204dcaa93b04402a6b8842c19d81b91\"\n",
" >\n",
" <label for=\"_anim_radio3_b27a23862fdf45fb928ec38c9af03191\">Reflect</label>\n",
" <label for=\"_anim_radio3_7204dcaa93b04402a6b8842c19d81b91\">Reflect</label>\n",
" </form>\n",
" </div>\n",
"</div>\n",
Expand All @@ -984,9 +983,9 @@
" /* Instantiate the Animation class. */\n",
" /* The IDs given should match those used in the template above. */\n",
" (function() {\n",
" var img_id = \"_anim_imgb27a23862fdf45fb928ec38c9af03191\";\n",
" var slider_id = \"_anim_sliderb27a23862fdf45fb928ec38c9af03191\";\n",
" var loop_select_id = \"_anim_loop_selectb27a23862fdf45fb928ec38c9af03191\";\n",
" var img_id = \"_anim_img7204dcaa93b04402a6b8842c19d81b91\";\n",
" var slider_id = \"_anim_slider7204dcaa93b04402a6b8842c19d81b91\";\n",
" var loop_select_id = \"_anim_loop_select7204dcaa93b04402a6b8842c19d81b91\";\n",
" var frames = new Array(8);\n",
" \n",
" frames[0] = \"\\\n",
Expand Down Expand Up @@ -6886,7 +6885,7 @@
" /* set a timeout to make sure all the above elements are created before\n",
" the object is initialized. */\n",
" setTimeout(function() {\n",
" animb27a23862fdf45fb928ec38c9af03191 = new Animation(frames, img_id, slider_id, 200.0,\n",
" anim7204dcaa93b04402a6b8842c19d81b91 = new Animation(frames, img_id, slider_id, 200.0,\n",
" loop_select_id);\n",
" }, 0);\n",
" })()\n",
Expand Down
16 changes: 15 additions & 1 deletion streamgen/nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""🪢 different node implementations using [anytree](https://anytree.readthedocs.io/en/stable/) `NodeMixin`."""

import inspect
from collections import deque
from collections.abc import Callable
from typing import Any, Protocol, runtime_checkable
Expand Down Expand Up @@ -112,16 +113,29 @@ def set_update_step(self, idx: int) -> None:
def fetch_params(self, params: ParameterStore) -> None:
"""⚙️ fetches params from a ParameterStore.
If the node was explicitly parameterized, use those params.
The parameters are fetched from both a matching scope and
the top-level/global scope with the scope having precedence.
Skips fetching if the node was explicitly parameterized.
Args:
params (ParameterStore): _description_
"""
if self.params:
return
self.params = ParameterStore()
if self.name in params.scopes:
self.params = params.get_scope(self.name)

# infer missing arguments that were not present in the scope of the transform
missing_arguments = [
param.name for param in inspect.signature(self.transform).parameters.values() if param.name not in self.params.parameter_names
]
# if those missing arguments are in the top-level scope, add those parameters
for param_name in missing_arguments:
if param_name in params.parameter_names:
self.params[param_name] = params[param_name]

def get_params(self) -> ParameterStore | None:
"""⚙️ returns current parameters.
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/samplingtree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def add_random_points(input, num_points): # noqa: A002
def add(input: int, number): # noqa: A002
return input + number

def add_and_subtract(input: int, number, number2): # noqa: A002
return input + number - number2


# ---------------------------------------------------------------------------- #
# * fixtures #
Expand All @@ -48,6 +51,48 @@ def add(input: int, number): # noqa: A002
# * tests #
# ---------------------------------------------------------------------------- #

def test_parameter_fetching_from_global_scope():
"""Tests if nodes fetch their missing arguments from the top-level/global scope."""
tree = SamplingTree(
[
lambda input: 0, # noqa: ARG005
{
"probs": Parameter("probs", schedule=[[1.0, 0.0], [0.0, 1.0]]),
"1": [
add,
"one",
],
"2": [
TransformNode(add, name="two"),
add_and_subtract,
"two",
],
},
TransformNode(operate_on_index()(add), Parameter("number", 3)),
],
{
"two": {
"number": 3
},
"add_and_subtract": {
"number": 5
},
"number": 1,
"number2": 2
}
)

output, target = tree.sample()

assert output == 4, "The last `partial(add, 3)` transform should be connected to both branches."
assert target == "one"

tree.update()
output, target = tree.sample()

assert output == 9
assert target == "two"


def test_sampling_tree_decision_node_with_probs():
"""Tests the initialization, sampling and parameter fetching of a `SamplingTree`."""
Expand Down

0 comments on commit 7650577

Please sign in to comment.