diff --git a/mlipx/nodes/adsorption.py b/mlipx/nodes/adsorption.py index d26e6c9..6f97caf 100644 --- a/mlipx/nodes/adsorption.py +++ b/mlipx/nodes/adsorption.py @@ -4,6 +4,7 @@ import ase import ase.io as aio import ase.optimize as opt +import numpy as np import plotly.graph_objects as go import zntrack @@ -214,6 +215,8 @@ def compare(cls, *nodes: "RelaxAdsorptionConfigs") -> ComparisonResults: relax_figures = {} + offset = 0 + for key in nodes[0].relaxations: for node in nodes: traj = node.relaxations[key] @@ -252,18 +255,64 @@ def compare(cls, *nodes: "RelaxAdsorptionConfigs") -> ComparisonResults: x=list(range(len(energies))), y=energies, name=node.name.replace(f"_{cls.__name__}", ""), + customdata=np.stack( + [np.arange(len(energies)) + offset], axis=1 + ), + mode="lines+markers", ) ) + fig.update_layout( + { + "plot_bgcolor": "rgba(0, 0, 0, 0)", + "paper_bgcolor": "rgba(0, 0, 0, 0)", + "yaxis_title": "Adsorption Energy / a.u.", + "xaxis_title": "Step", + } + ) + fig.update_xaxes( + showgrid=True, + gridwidth=1, + gridcolor="rgba(120, 120, 120, 0.3)", + zeroline=False, + ) + fig.update_yaxes( + showgrid=True, + gridwidth=1, + gridcolor="rgba(120, 120, 120, 0.3)", + zeroline=False, + ) + relax_figures[config_site if config_site else config_type] = fig full_traj.extend(traj) + offset += len(traj) ads_e[node.name.replace(f"_{cls.__name__}", "")] = node.ads_energies fig = go.Figure() for key, val in ads_e.items(): fig.add_trace(go.Bar(x=list(val.keys()), y=list(val.values()), name=key)) + fig.update_layout( + { + "plot_bgcolor": "rgba(0, 0, 0, 0)", + "paper_bgcolor": "rgba(0, 0, 0, 0)", + "yaxis_title": "Adsorption Energy / a.u.", + } + ) + fig.update_xaxes( + showgrid=True, + gridwidth=1, + gridcolor="rgba(120, 120, 120, 0.3)", + zeroline=False, + ) + fig.update_yaxes( + showgrid=True, + gridwidth=1, + gridcolor="rgba(120, 120, 120, 0.3)", + zeroline=False, + ) + relax_figures["adsorption_energies"] = fig return {"frames": full_traj, "figures": relax_figures}