Skip to content

Commit

Permalink
add cluster's name in draw() method
Browse files Browse the repository at this point in the history
  • Loading branch information
bezumbzalinho committed Jul 28, 2024
1 parent 27462bb commit 97e9716
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions river/cluster/odac.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def render_ascii(self, n_decimal_places: int = 2) -> str:

return self._root_node.design_structure(n_decimal_places).rstrip("\n")

def draw(self, max_depth: int | None = None, show_clusters_info: list[str] = ["timeseries_names", "d1", "d2", "e"], n_decimal_places: int = 2):
def draw(self, max_depth: int | None = None, show_clusters_info: list[typing.Hashable] = ["timeseries_names", "d1", "d2", "e"], n_decimal_places: int = 2):
"""Method to draw the hierarchical cluster's structure as a Graphviz graph.
Parameters
Expand All @@ -250,10 +250,11 @@ def draw(self, max_depth: int | None = None, show_clusters_info: list[str] = ["t
The maximum depth of the tree to display.
show_clusters_info
List of cluster information to show. Valid options are:
- "timeseries_indexes": Shows the indexes of the timeseries.
- "timeseries_names": Shows the names of the timeseries.
- "d1": Shows the d1 (the largest distance in each cluster).
- "d2": Shows the d2 (the second largest distance in each cluster).
- "timeseries_indexes": Shows the indexes of the timeseries in the cluster.
- "timeseries_names": Shows the names of the timeseries in the cluster.
- "name": Shows the cluster's name.
- "d1": Shows the d1 (the largest distance in the cluster).
- "d2": Shows the d2 (the second largest distance in the cluster).
- "e": Shows the error bound.
n_decimal_places
The number of decimal places to show for numerical values.
Expand Down Expand Up @@ -292,15 +293,20 @@ def iterate(node: ODACCluster, parent_node: str | None = None, depth: int = 0):

label = ""

show_clusters_info_copy = show_clusters_info.copy()

# checks if user wants to see information about clusters
if len(show_clusters_info_copy) > 0:
if len(show_clusters_info) > 0:
show_clusters_info_copy = show_clusters_info.copy()

if "name" in show_clusters_info_copy:
label += f"{node.name}"
show_clusters_info_copy.remove("name")
if len(show_clusters_info_copy) > 0:
label += "\n"
if "timeseries_indexes" in show_clusters_info_copy:
# Convert timeseries names to indexes
name_to_index = {name: index for index, name in enumerate(self._root_node.timeseries_names)}
timeseries_indexes = [name_to_index[_name] for _name in node.timeseries_names if _name in name_to_index]

label += f"{timeseries_indexes}"
show_clusters_info_copy.remove("timeseries_indexes")
if len(show_clusters_info_copy) > 0:
Expand All @@ -326,9 +332,9 @@ def iterate(node: ODACCluster, parent_node: str | None = None, depth: int = 0):
if "e" in show_clusters_info_copy:
label += f"e={node.e:.{n_decimal_places}f}"

show_clusters_info_copy.clear()
show_clusters_info_copy.clear()

# Creates a node with different colors to differentiate the active clusters from the non-active
# Creates a node with different color to differentiate the active clusters from the non-active
if node.active:
dot.node(node_n, label, style="filled", fillcolor="#76b5c5")
else:
Expand Down Expand Up @@ -360,7 +366,7 @@ def __init__(self, name: str, parent: ODACCluster | None = None):
self.children: ODACChildren | None = None

self.timeseries_names: list[typing.Hashable] = []
self._statistics: dict[tuple[typing.Hashable, typing.Hashable], stats.PearsonCorr] | stats.Var
self._statistics: dict[tuple[typing.Hashable, typing.Hashable], stats.PearsonCorr] | stats.Var | None

self.d1: float | None = None
self.d2: float | None = None
Expand Down Expand Up @@ -524,7 +530,9 @@ def _split_this_cluster(self, pivot_1: typing.Hashable, pivot_2: typing.Hashable

# Set the active flag to false. Since this cluster is not an active cluster anymore.
self.active = False
self.avg = self.d0 = self.pivot_0 = self.pivot_1 = self.pivot_2 = None # type: ignore

# Reset some attributes
self.avg = self.d0 = self.pivot_0 = self.pivot_1 = self.pivot_2 = self._statistics = None # type: ignore

# Method that proceeds to merge on this cluster
def _aggregate_this_cluster(self):
Expand Down

0 comments on commit 97e9716

Please sign in to comment.