Skip to content

Commit

Permalink
add time variable test for synapse
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Apr 8, 2024
1 parent 03bc52d commit 7e81a2f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 15 deletions.
4 changes: 1 addition & 3 deletions tests/nest_tests/resources/TimeVariablePrePostSynapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ model time_variable_pre_post_synapse:
x ms = 0 ms
y ms = 0 ms
z ms = 0 ms
q ms = 0 ms

onReceive(pre_spikes):
y = t
Expand All @@ -44,5 +43,4 @@ model time_variable_pre_post_synapse:
post_spikes <- spike

update:
x = t + resolution() + 1 ms
q = x
x = t + resolution()
2 changes: 1 addition & 1 deletion tests/nest_tests/resources/TimeVariableSynapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ model time_variable_synapse:
pre_spikes <- spike

update:
x = t + resolution() + 1 ms
x = t + resolution()
102 changes: 91 additions & 11 deletions tests/nest_tests/test_time_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import os
import pytest
import scipy.signal

import nest

Expand Down Expand Up @@ -56,6 +57,7 @@ def setUp(self):

def test_time_variable_neuron(self):
nest.ResetKernel()
nest.resolution = .25 # [ms]
try:
nest.Install("nestmlmodule")
except Exception:
Expand All @@ -78,6 +80,7 @@ def test_time_variable_neuron(self):
def test_time_variable_synapse(self):
"""a synapse is only updated when presynaptic spikes arrive"""
nest.ResetKernel()
nest.resolution = .25 # [ms]
try:
nest.Install("nestmlmodule")
except Exception:
Expand All @@ -89,15 +92,19 @@ def test_time_variable_synapse(self):
nest.Connect(nrn[0], sr)
nest.Connect(nrn[0], nrn[1], syn_spec={"synapse_model": "time_variable_synapse_nestml"})
syn = nest.GetConnections(nrn[0], nrn[1])
syn.delay = nest.resolution # [ms]
syn.d = nest.resolution
assert len(syn) == 1

sr_pre = nest.Create("spike_recorder")
sr_post = nest.Create("spike_recorder")
nest.Connect(nrn[0], sr_pre)
nest.Connect(nrn[1], sr_post)

nest.set_verbosity("M_FATAL")

T_sim = 50. # [ms]
sim_interval = 1. # [ms]
sim_interval = nest.resolution # [ms]
timevec = [0.]
x = [syn[0].get("x")]
y = [syn[0].get("y")]
Expand All @@ -109,6 +116,23 @@ def test_time_variable_synapse(self):

assert len(sr.get("events")["times"]) > 2, "Was expecting some more presynaptic spikes"

#
# analysis
#

timevec = np.array(timevec)
x = np.array(x)
y = np.array(y)
x_error = np.abs(timevec - x)
y_error = np.abs(timevec - y)

x_peaks_idx, _ = scipy.signal.find_peaks(-x_error)
y_peaks_idx, _ = scipy.signal.find_peaks(-y_error)

#
# plot
#

fig, ax = plt.subplots(nrows=4, figsize=(8, 8))

ax[0].scatter(sr_pre.get("events")["times"], np.zeros_like(sr_pre.get("events")["times"]))
Expand All @@ -119,9 +143,15 @@ def test_time_variable_synapse(self):

ax[2].plot(timevec, x, label="x")
ax[2].plot(timevec, timevec, linestyle="--", c="gray")
ax2_ = ax[2].twinx()
ax2_.plot(timevec, x_error, c="red")
ax2_.scatter(timevec[x_peaks_idx], x_error[x_peaks_idx], edgecolor="red", facecolor="none")

ax[3].plot(timevec, y, label="y")
ax[3].plot(timevec, timevec, linestyle="--", c="gray")
ax3_ = ax[3].twinx()
ax3_.plot(timevec, y_error, c="red")
ax3_.scatter(timevec[y_peaks_idx], y_error[y_peaks_idx], edgecolor="red", facecolor="none")

ax[-1].set_ylabel("Time [ms]")

Expand All @@ -130,46 +160,70 @@ def test_time_variable_synapse(self):
_ax.legend()
_ax.set_xlim(-1, T_sim + 1)

fig.savefig("/tmp/foo1.png")
fig.savefig("/tmp/test_time_variable_synapse.png")

# np.testing.assert_allclose(x, sr.get("events")["times"][-2])
# np.testing.assert_allclose(y, sr.get("events")["times"][-1])
#
# testing
#

assert all(x_error[x_peaks_idx] <= nest.resolution + 1E-12)
assert all(y_error[y_peaks_idx] <= nest.resolution + 1E-12)

def test_time_variable_pre_post_synapse(self):
"""a synapse is updated when pre- and postsynaptic spikes arrive"""
nest.ResetKernel()
nest.resolution = .25 # [ms]
try:
nest.Install("nestmlmodule")
except Exception:
# ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
pass
nrn = nest.Create("iaf_psc_delta_neuron_nestml__with_time_variable_pre_post_synapse_nestml", 2)
nrn[0].I_e = 1000. # [pA]
nrn[1].I_e = 2000. # [pA]
nrn[1].I_e = 1500. # [pA]
sr_pre = nest.Create("spike_recorder")
sr_post = nest.Create("spike_recorder")
nest.Connect(nrn[0], sr_pre)
nest.Connect(nrn[1], sr_post)
nest.Connect(nrn[0], nrn[1], syn_spec={"synapse_model": "time_variable_pre_post_synapse_nestml__with_iaf_psc_delta_neuron_nestml"})
syn = nest.GetConnections(nrn[0], nrn[1])
syn.delay = nest.resolution # [ms]
syn.d = nest.resolution
assert len(syn) == 1

T_sim = 20. # [ms]
sim_interval = 1. # [ms]
sim_interval = nest.resolution # [ms]
timevec = [0.]
x = [syn[0].get("x")]
y = [syn[0].get("y")]
z = [syn[0].get("z")]
# n_post_spikes = [syn[0].get("n_post_spikes")]
while nest.biological_time < T_sim:
nest.Simulate(sim_interval)
timevec.append(nest.biological_time)
x.append(syn[0].get("x"))
y.append(syn[0].get("y"))
z.append(syn[0].get("z"))
# n_post_spikes.append(syn[0].get("n_post_spikes"))

#
# analysis
#

# assert len(sr_pre.get("events")["times"]) > 2, "Was expecting some more presynaptic spikes"
# assert len(sr_post.get("events")["times"]) > 2, "Was expecting some more presynaptic spikes"
timevec = np.array(timevec)
x = np.array(x)
y = np.array(y)
x_error = np.abs(timevec - x)
y_error = np.abs(timevec - y)
z_error = np.abs(timevec - z)

x_peaks_idx, _ = scipy.signal.find_peaks(-x_error)
y_peaks_idx, _ = scipy.signal.find_peaks(-y_error)
z_peaks_idx, _ = scipy.signal.find_peaks(-z_error)

#
# plot
#

fig, ax = plt.subplots(nrows=5, figsize=(8, 8))

Expand All @@ -181,12 +235,21 @@ def test_time_variable_pre_post_synapse(self):

ax[2].plot(timevec, x, label="x")
ax[2].plot(timevec, timevec, linestyle="--", c="gray")
ax2_ = ax[2].twinx()
ax2_.plot(timevec, x_error, c="red")
ax2_.scatter(timevec[x_peaks_idx], x_error[x_peaks_idx], edgecolor="red", facecolor="none")

ax[3].plot(timevec, y, label="y")
ax[3].plot(timevec, timevec, linestyle="--", c="gray")
ax3_ = ax[3].twinx()
ax3_.plot(timevec, y_error, c="red")
ax3_.scatter(timevec[y_peaks_idx], y_error[y_peaks_idx], edgecolor="red", facecolor="none")

ax[4].plot(timevec, z, label="z")
ax[4].plot(timevec, timevec, linestyle="--", c="gray")
ax4_ = ax[4].twinx()
ax4_.plot(timevec, z_error, c="red")
ax4_.scatter(timevec[z_peaks_idx], z_error[z_peaks_idx], edgecolor="red", facecolor="none")

ax[-1].set_ylabel("Time [ms]")

Expand All @@ -195,7 +258,24 @@ def test_time_variable_pre_post_synapse(self):
_ax.legend()
_ax.set_xlim(-1, T_sim + 1)

fig.savefig("/tmp/foo.png")
fig.savefig("/tmp/test_time_variable_pre_post_synapse.png")

#
# testing
#

assert all(x_error[x_peaks_idx] <= nest.resolution + 1E-12)
assert all(y_error[y_peaks_idx] <= nest.resolution + 1E-12)

# for z, which is assigned to on post spike times but only actually updated when a pre spike is processed, compare to the last post spike time
post_spike_times = sr_post.get("events")["times"]
for i, t in enumerate(timevec[z_peaks_idx]):
# find the last post spike before t
for i_, t_post_sp in enumerate(post_spike_times):
if t_post_sp > t:
t_post_sp = post_spike_times[i_ - 1]
break

time_interval = t - t_post_sp

np.testing.assert_allclose(x, sr.get("events")["times"][-2])
np.testing.assert_allclose(y, sr.get("events")["times"][-1])
assert z_error[z_peaks_idx[i]] <= time_interval + nest.resolution + 1E-12

0 comments on commit 7e81a2f

Please sign in to comment.