diff --git a/cotengra/plot.py b/cotengra/plot.py index 8757ed4..5e4704c 100644 --- a/cotengra/plot.py +++ b/cotengra/plot.py @@ -1146,6 +1146,7 @@ def plot_contractions( color_size=(0.6, 0.4, 0.7), color_cost=(0.3, 0.7, 0.5), figsize=(8, 3), + ax=None, ): import matplotlib.pyplot as plt @@ -1160,12 +1161,15 @@ def plot_contractions( sz -= tree.get_size(l) sz -= tree.get_size(r) sizes.append(math.log2(tree.get_size(p))) + costs.append(math.log10(tree.get_flops(p))) cons = list(range(len(peaks))) - fig, ax = plt.subplots(figsize=figsize) - + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = None ax.set_xlabel("contraction") ax.plot( @@ -1552,8 +1556,6 @@ def plot_tree_rubberband( tree, order=None, colormap="Spectral", - with_edge_labels=None, - with_node_labels=None, highlight=(), centrality=False, layout="auto", @@ -1598,6 +1600,7 @@ def plot_tree_rubberband( figsize=figsize, info=info, show_and_close=False, + ax=ax, ) pos = info["pos"] r0 = info["node_size"] @@ -1647,6 +1650,7 @@ def plot_tree_flat( node_labels_font_family="monospace", show_sliced=True, figsize=None, + ax=None, ): """Plot a ``ContractionTree`` as a flat, 2D diagram, including all indices at every intermediate contraction. This can be useful for small @@ -1720,7 +1724,7 @@ def plot_tree_flat( family=node_labels_font_family, ) - d = Drawing(figsize=figsize) + d = Drawing(ax=ax, figsize=figsize) # order the leaves are contracted in leaf_order = {leaf: i for i, leaf in enumerate(tree.get_leaves_ordered())} @@ -1877,6 +1881,7 @@ def plot_tree_circuit( node_colormap="YlOrRd", node_max_size=None, figsize=None, + ax=None, ): import matplotlib as mpl @@ -1885,7 +1890,7 @@ def plot_tree_circuit( if figsize is None: figsize = (tree.N**0.75, tree.N**0.75) - d = Drawing(figsize=figsize) + d = Drawing(ax=ax, figsize=figsize) # edge coloring -> node size if edge_max_width is None: