Skip to content

Commit

Permalink
improve test_iaf_psc_alpha_multisynapse
Browse files Browse the repository at this point in the history
  • Loading branch information
janeirik committed Nov 25, 2023
1 parent b40e2b9 commit 09f83d6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def test_simulation_against_analytical_solution():
nest.Connect(mm, nrn, syn_spec={"delay": 0.1})
nest.Simulate(simtime)
times = mm.get("events", "times")

I_syns_analytical = []
V_m_analytical = np.zeros_like(times, dtype=np.float64)
V_m_analytical = np.zeros_like(times)
for weight, delay, tau_s in zip(weights, delays, tau_syns):
I_syns_analytical.append(exp_psc_fn(times - delay - spike_time, tau_s) * weight)
V_m_analytical += exp_psc_voltage_response(times - delay - spike_time, tau_s, tau_m, C_m, weight)
Expand All @@ -133,6 +134,7 @@ def test_default_recordables():
assert "I_syn_1" in recordables
assert "V_m" in recordables


def test_resize_recordables():
"""
Test resizing of recordables.
Expand Down
65 changes: 46 additions & 19 deletions testsuite/pytests/test_iaf_psc_alpha_multisynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import numpy.testing as nptest
import pytest
from scipy.linalg import expm


@pytest.fixture(autouse=True)
Expand All @@ -37,8 +38,7 @@ def reset():

def alpha_fn(t, tau_syn):
vals = np.zeros_like(t)
zero_inds = t <= 0.0
nonzero_inds = ~zero_inds
nonzero_inds = t > 0.0
vals[nonzero_inds] = np.e / tau_syn * t[nonzero_inds] * np.exp(-t[nonzero_inds] / tau_syn)
return vals

Expand All @@ -50,6 +50,16 @@ def test_I_syn_1_in_recordables():
assert "I_syn_1" in nrn.get("recordables")


def alpha_psc_voltage_response(t, tau_syn, tau_m, C_m, w):
vals = np.zeros_like(t)
nonzero_inds = t > 0.0
A = np.array([[-1.0 / tau_syn, 0.0, 0.0], [1.0, -1.0 / tau_syn, 0.0], [0.0, 1.0 / C_m, -1.0 / tau_m]])

expAt = expm(A[None, ...] * t[nonzero_inds, None, None]) # shape (t, 3, 3)
vals[nonzero_inds] = expAt[:, 2, 0] * w * np.e / tau_syn # first two state variables are 0
return vals


def test_resize_recordables():
"""
Test resizing of recordables.
Expand Down Expand Up @@ -80,40 +90,57 @@ def test_simulation_against_analytical_soln():
from multiple different synaptic ports are the same as the analytical solution.
"""

tau_syn = [2.0, 20.0, 60.0, 100.0]
delays = [100.0, 200.0, 500.0, 1200.0]
weight = 1.0
spike_time = 10.0
simtime = 2500.0
tau_syns = [2.0, 20.0, 60.0, 100.0]
delays = [7.0, 5.0, 2.0, 1.0]
weights = [30.0, 50.0, 20.0, 10.0]
C_m = 250.0
tau_m = 15.0
spike_time = 1.0
simtime = 100.0
dt = 1.0

nest.set(resolution=dt)

nrn = nest.Create(
"iaf_psc_alpha_multisynapse",
params={
"C_m": 250.0,
"C_m": C_m,
"E_L": 0.0,
"V_m": 0.0,
"V_th": 1500.0,
"I_e": 0.0,
"tau_m": 15.0,
"tau_syn": tau_syn,
"tau_m": tau_m,
"tau_syn": tau_syns,
},
)
sg = nest.Create("spike_generator", params={"spike_times": [spike_time]})

for i, syn_id in enumerate(range(1, 5)):
syn_spec = {"synapse_model": "static_synapse", "delay": delays[i], "weight": weight, "receptor_type": syn_id}
sg = nest.Create("spike_generator", params={"spike_times": [spike_time]})

for syn_idx, (delay, weight) in enumerate(zip(delays, weights)):
syn_spec = {
"synapse_model": "static_synapse",
"delay": delay,
"weight": weight,
"receptor_type": syn_idx + 1,
}
nest.Connect(sg, nrn, conn_spec="one_to_one", syn_spec=syn_spec)

mm = nest.Create("multimeter", params={"record_from": ["I_syn_1", "I_syn_2", "I_syn_3", "I_syn_4"]})
mm = nest.Create(
"multimeter",
params={"record_from": ["I_syn_1", "I_syn_2", "I_syn_3", "I_syn_4", "V_m", "I_syn"], "interval": dt},
)

nest.Connect(mm, nrn)
nest.Simulate(simtime)
times = mm.get("events", "times")
I_syn = np.sum([mm.get("events", f"I_syn_{i}") for i in range(1, 5)], axis=0)

I_syn_analytical = np.zeros_like(times, dtype=np.float64)
for i in range(4):
I_syn_analytical += alpha_fn(times - delays[i] - spike_time, tau_syn[i])
I_syns_analytical = []
V_m_analytical = np.zeros_like(times)
for weight, delay, tau_s in zip(weights, delays, tau_syns):
I_syns_analytical.append(alpha_fn(times - delay - spike_time, tau_s) * weight)
V_m_analytical += alpha_psc_voltage_response(times - delay - spike_time, tau_s, tau_m, C_m, weight)

for idx, I_syn_analytical in enumerate(I_syns_analytical):
nptest.assert_array_almost_equal(mm.get("events", f"I_syn_{idx+1}"), I_syn_analytical)

nptest.assert_array_almost_equal(I_syn, I_syn_analytical)
nptest.assert_array_almost_equal(mm.get("events", "V_m"), V_m_analytical)

0 comments on commit 09f83d6

Please sign in to comment.