
Commit

Add titles to figures
Co-authored-by: clinssen <[email protected]>
akorgor and clinssen committed Jun 21, 2024
1 parent 980d03a commit 3b24f4b
Showing 7 changed files with 75 additions and 24 deletions.
Changed file 1 of 7
@@ -681,6 +681,7 @@ def get_weights(pop_pre, pop_post):
# plotted against the iterations.

fig, axs = plt.subplots(2, 1, sharex=True)
fig.suptitle("Training error")

axs[0].plot(range(1, n_iter + 1), loss)
axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
@@ -722,8 +723,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims)
@@ -775,6 +780,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -800,6 +806,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -822,8 +829,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

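The change above, repeated in the remaining files, pairs a figure title with each plotting window and sets it via fig.suptitle(). A minimal, self-contained sketch of that pattern follows; the spike data and step counts are invented for illustration and are not taken from the tutorial scripts.

import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the recorded spike events and step counts.
steps = {"sequence": 100, "task": 1000}
rng = np.random.default_rng(1)
times = np.sort(rng.uniform(0, steps["task"], 500))
senders = rng.integers(1, 21, size=times.size)

for title, xlims in zip(
    ["Dynamic variables before training", "Dynamic variables after training"],
    [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
    fig, ax = plt.subplots()
    fig.suptitle(title)  # one title for the whole figure
    ax.scatter(times, senders, s=2)
    ax.set_xlim(*xlims)  # show only the selected window
    ax.set_xlabel("time step")
    ax.set_ylabel("sender")

plt.show()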
Changed file 2 of 7
@@ -659,6 +659,7 @@ def get_weights(pop_pre, pop_post):
# plotted against the iterations.

fig, axs = plt.subplots(2, 1, sharex=True)
fig.suptitle("Training error")

axs[0].plot(range(1, n_iter + 1), loss)
axs[0].set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$")
@@ -700,8 +701,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_reg, r"$z_j$" + "\n", xlims)
@@ -753,6 +758,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -778,6 +784,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -800,8 +807,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

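Both the regression and the classification scripts title the error figure identically; only the loss definition in the y-axis label differs. A small sketch with a made-up loss trace, showing the cross-entropy label used here (the regression scripts use the mean-squared-error label instead):

import matplotlib.pyplot as plt
import numpy as np

n_iter = 50
loss = np.exp(-np.linspace(0, 3, n_iter))  # invented, monotonically decreasing loss

fig, ax = plt.subplots()
fig.suptitle("Training error")

ax.plot(range(1, n_iter + 1), loss)
# Cross-entropy loss label; the regression variant uses
# r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$".
ax.set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$")
ax.set_xlabel("training iteration")
plt.show()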
Changed file 3 of 7
@@ -743,6 +743,7 @@ def evaluate(n_iteration, iter_start):
# plotted against the iterations.

fig, axs = plt.subplots(2, 1, sharex=True)
fig.suptitle("Training error")

axs[0].plot(range(1, n_iter + 1), loss)
axs[0].set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
@@ -784,11 +785,15 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [
(steps["pre_sim"], steps["pre_sim"] + steps["sequence"]),
(steps["pre_sim"] + steps["task"] - steps["sequence"], steps["pre_sim"] + steps["task"]),
]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[
(steps["pre_sim"], steps["pre_sim"] + steps["sequence"]),
(steps["pre_sim"] + steps["task"] - steps["sequence"], steps["pre_sim"] + steps["task"]),
],
):
fig, axs = plt.subplots(9, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims)
@@ -833,6 +838,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -858,6 +864,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -880,8 +887,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

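This file shifts both plotting windows by a pre-simulation phase. A short sketch of the window arithmetic with invented step counts: the "before training" window starts right after the pre-simulation steps and spans one sequence, and the "after training" window covers the last sequence of the task.

# Hypothetical step counts; the real values are defined earlier in the script.
steps = {"pre_sim": 300, "sequence": 100, "task": 1000}

windows = {
    "Dynamic variables before training": (
        steps["pre_sim"],
        steps["pre_sim"] + steps["sequence"],
    ),
    "Dynamic variables after training": (
        steps["pre_sim"] + steps["task"] - steps["sequence"],
        steps["pre_sim"] + steps["task"],
    ),
}

for title, (t_start, t_stop) in windows.items():
    assert t_stop - t_start == steps["sequence"]  # each window spans exactly one sequence
    print(f"{title}: steps {t_start} to {t_stop}")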
Changed file 4 of 7
@@ -563,6 +563,7 @@ def get_weights(pop_pre, pop_post):
# neurons encode the horizontal and vertical coordinate of the pattern respectively.

fig, ax = plt.subplots()
fig.suptitle("Pattern")

ax.plot(readout_signal[0, -1, 0, :], -readout_signal[1, -1, 0, :], c=colors["red"], label="readout")

@@ -581,6 +582,7 @@ def get_weights(pop_pre, pop_post):
# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations.

fig, ax = plt.subplots()
fig.suptitle("Training error")

ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--")
ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted")
@@ -621,8 +623,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(10, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims)
@@ -668,6 +674,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -693,6 +700,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -715,8 +723,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

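This script additionally titles the two-dimensional pattern plot, in which two readout neurons trace the target shape. A self-contained sketch with a synthetic readout signal; the array shape and the interpretation of its axes as (coordinate, iteration, batch, time) are assumptions for illustration only.

import matplotlib.pyplot as plt
import numpy as np

n_iter, n_batch, n_time = 5, 1, 200
t = np.linspace(0, 2 * np.pi, n_time)

# Invented readout signal: axis 0 holds the horizontal and vertical coordinates.
readout_signal = np.zeros((2, n_iter, n_batch, n_time))
readout_signal[0, -1, 0, :] = np.cos(t)
readout_signal[1, -1, 0, :] = np.sin(2 * t)

fig, ax = plt.subplots()
fig.suptitle("Pattern")

# The vertical coordinate is negated, mirroring the plotting convention in the script.
ax.plot(readout_signal[0, -1, 0, :], -readout_signal[1, -1, 0, :], label="readout")
ax.set_xlabel("horizontal coordinate")
ax.set_ylabel("vertical coordinate")
ax.legend()
plt.show()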
Changed file 5 of 7
@@ -554,6 +554,7 @@ def get_weights(pop_pre, pop_post):
# neurons encode the horizontal and vertical coordinate of the pattern respectively.

fig, ax = plt.subplots()
fig.suptitle("Pattern")

ax.plot(readout_signal[0, -1, 0, :], -readout_signal[1, -1, 0, :], c=colors["red"], label="readout")

@@ -572,6 +573,7 @@ def get_weights(pop_pre, pop_post):
# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations.

fig, ax = plt.subplots()
fig.suptitle("Training error")

ax.plot(range(1, n_iter + 1), loss_list[0], label=r"$E_0$", alpha=0.8, c=colors["blue"], ls="--")
ax.plot(range(1, n_iter + 1), loss_list[1], label=r"$E_1$", alpha=0.8, c=colors["blue"], ls="dotted")
@@ -612,8 +614,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(10, 1, sharex=True, figsize=(8, 12), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims)
@@ -659,6 +665,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -684,6 +691,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -706,8 +714,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

Changed file 6 of 7
@@ -553,6 +553,7 @@ def get_weights(pop_pre, pop_post):
# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations.

fig, ax = plt.subplots()
fig.suptitle("Training error")

ax.plot(range(1, n_iter + 1), loss)
ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
@@ -590,8 +591,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(9, 1, sharex=True, figsize=(6, 8), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims)
@@ -636,6 +641,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -661,6 +667,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -683,8 +690,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

Changed file 7 of 7
@@ -526,6 +526,7 @@ def get_weights(pop_pre, pop_post):
# We begin with a plot visualizing the training error of the network: the loss plotted against the iterations.

fig, ax = plt.subplots()
fig.suptitle("Training error")

ax.plot(range(1, n_iter + 1), loss)
ax.set_ylabel(r"$E = \frac{1}{2} \sum_{t,k} \left( y_k^t -y_k^{*,t}\right)^2$")
@@ -563,8 +564,12 @@ def plot_spikes(ax, events, ylabel, xlims):
ax.set_ylim(np.min(senders_subset) - margin, np.max(senders_subset) + margin)


for xlims in [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])]:
for title, xlims in zip(
["Dynamic variables before training", "Dynamic variables after training"],
[(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])],
):
fig, axs = plt.subplots(9, 1, sharex=True, figsize=(6, 8), gridspec_kw={"hspace": 0.4, "left": 0.2})
fig.suptitle(title)

plot_spikes(axs[0], events_sr_in, r"$z_i$" + "\n", xlims)
plot_spikes(axs[1], events_sr_rec, r"$z_j$" + "\n", xlims)
@@ -609,6 +614,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe


fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(
@@ -634,6 +640,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
)

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

all_w_extrema = []

@@ -656,8 +663,8 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

axs[0, 0].text(0.5, 1.1, "pre-training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "post-training", transform=axs[0, 1].transAxes, ha="center")
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

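The weight-matrix panels combine several of the conventions touched by this commit: a figure-level title via fig.suptitle(), per-column captions placed with ax.text() in axes coordinates, aligned row labels, and integer y-ticks. A self-contained sketch with random weight matrices; the shapes and values are invented, and only the layout mirrors the scripts.

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Invented weight matrices: one row per population (input, recurrent, readout),
# one column per state (before and after training).
weights = [[rng.normal(size=(5, 5)) for _ in range(2)] for _ in range(3)]

fig, axs = plt.subplots(3, 2, sharex="col", sharey="row")
fig.suptitle("Weight matrices")

for i in range(3):
    for j in range(2):
        axs[i, j].pcolormesh(weights[i][j])

axs[0, 0].set_ylabel("input\nneurons")
axs[1, 0].set_ylabel("recurrent\nneurons")
axs[2, 0].set_ylabel("readout\nneurons")
fig.align_ylabels(axs[:, 0])

# Column captions just above the top row, positioned in axes coordinates.
axs[0, 0].text(0.5, 1.1, "before training", transform=axs[0, 0].transAxes, ha="center")
axs[0, 1].text(0.5, 1.1, "after training", transform=axs[0, 1].transAxes, ha="center")

axs[2, 0].yaxis.get_major_locator().set_params(integer=True)

plt.show()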
