
Commit

Merge remote-tracking branch 'upstream/master' into integrate_specific_odes
C.A.P. Linssen committed Mar 17, 2023
2 parents 10a1042 + 291daa1 commit f87f135
Showing 20 changed files with 218 additions and 88 deletions.
11 changes: 11 additions & 0 deletions doc/running.rst
@@ -160,6 +160,17 @@ NEST Simulator target

After NESTML completes, the NEST extension module (by default called ``"nestmlmodule"``) can either be statically linked into NEST (see `Writing an extension module <https://nest-extension-module.readthedocs.io/>`_), or loaded dynamically using the ``Install`` API call in Python.
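For example, loading the module dynamically from Python takes a single call (a minimal sketch; ``iaf_psc_delta`` stands in for whichever model was generated into the module):

.. code-block:: python

   import nest

   nest.Install("nestmlmodule")            # load the generated extension module
   neuron = nest.Create("iaf_psc_delta")   # instantiate a model from the module
   nest.Simulate(100.)                     # advance the network by 100 ms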


Simulation loop
~~~~~~~~~~~~~~~

At the beginning of each timestep, incoming spikes become visible in the state variables that correspond to a convolution with the associated spiking input port.

Then, the code corresponding to the NESTML ``update`` block is run.

At the end of the timestep, variables corresponding to convolutions are updated according to their ODE dynamics.
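
In pseudocode, a single timestep of the generated update loop thus proceeds as follows (a schematic sketch with illustrative function names, not the literal generated code):

.. code-block:: python

   for lag in range(from_step, to_step):
       # spikes applied at the end of the previous step are now visible
       # in the state variables corresponding to convolutions
       run_update_block()           # code from the NESTML ``update`` block
       integrate_convolutions()     # advance convolution variables by their ODE dynamics
       apply_buffered_spikes(lag)   # incoming spikes become visible at the next timestep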


Code generation options
~~~~~~~~~~~~~~~~~~~~~~~

16 changes: 0 additions & 16 deletions models/neurons/iaf_psc_delta.nestml
@@ -52,7 +52,6 @@ iaf_psc_alpha, iaf_psc_exp
neuron iaf_psc_delta:

state:
refr_spikes_buffer mV = 0 mV
r integer = 0 # Counts the number of ticks during the refractory period
V_m mV = E_L # Membrane potential
end
@@ -71,7 +70,6 @@ neuron iaf_psc_delta:
V_reset mV = -70 mV # Reset potential of the membrane
V_th mV = -55 mV # Spike threshold
V_min mV = -inf * 1 mV # Absolute lower value for the membrane potential
with_refr_input boolean = false # If true, do not discard input during refractory period.

# constant external input current
I_e pA = 0 pA
@@ -93,23 +91,9 @@ neuron iaf_psc_delta:
if r == 0: # neuron not refractory
integrate_odes()

# if we have accumulated spikes from refractory period,
# add and reset accumulator
if with_refr_input and refr_spikes_buffer != 0.0 mV:
V_m += refr_spikes_buffer
refr_spikes_buffer = 0.0 mV
end

# lower bound of membrane potential
V_m = V_m < V_min ? V_min : V_m

else: # neuron is absolute refractory
# read spikes from buffer and accumulate them, discounting
# for decay until end of refractory period
# the buffer is cleared automatically
if with_refr_input:
refr_spikes_buffer += spikes * exp(-r * h / tau_m) * mV/pA
end
r -= 1
end

2 changes: 1 addition & 1 deletion models/synapses/static_synapse.nestml
@@ -6,7 +6,7 @@ Description
+++++++++++
A synapse where the synaptic strength (weight) does not evolve with simulated time, but is defined as a (constant) parameter.
"""
synapse static:
synapse static_synapse:

parameters:
w real = 900 @nest::weight @homogeneous
18 changes: 14 additions & 4 deletions pynestml/codegeneration/printers/latex_function_call_printer.py
@@ -23,6 +23,7 @@

from pynestml.codegeneration.printers.function_call_printer import FunctionCallPrinter
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_node import ASTNode
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.utils.ast_utils import ASTUtils

@@ -32,6 +33,11 @@ class LatexFunctionCallPrinter(FunctionCallPrinter):
Printer for ASTFunctionCall in LaTeX syntax.
"""

def print(self, node: ASTNode) -> str:
assert isinstance(node, ASTFunctionCall)

return self.print_function_call(node)

def _print_function_call(self, node: ASTFunctionCall) -> str:
r"""
Converts a single function call to LaTeX syntax.
@@ -70,15 +76,19 @@ def _print_function_call(self, node: ASTFunctionCall) -> str:
return r"\text{" + function_name + r"}"

def print_function_call(self, function_call: ASTFunctionCall) -> str:
function_name = self._print_function_name(function_call)
result = self._print_function_call(function_call)

if ASTUtils.needs_arguments(function_call):
return function_name % self._print_function_call_argument_list(function_call)
n_args = len(function_call.get_args())
result += '(' + ', '.join(['%s' for _ in range(n_args)]) + ')'
else:
result += '()'

return function_name
return result % self._print_function_call_argument_list(function_call)

def _print_function_call_argument_list(self, function_call: ASTFunctionCall) -> Tuple[str, ...]:
ret = []
for arg in function_call.get_args():
ret.append(self.print_expression(arg))
ret.append(self._expression_printer.print(arg))

return tuple(ret)
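
The rewritten ``print_function_call()`` first renders the function name, then appends one ``%s`` placeholder per argument, and finally substitutes the printed arguments via ``%``-formatting. A minimal standalone sketch of that technique (plain Python, independent of the NESTML printer classes):

def render_call(name, args):
    # one "%s" placeholder per argument, then substitute the printed arguments
    placeholders = ", ".join(["%s"] * len(args))
    return (name + "(" + placeholders + ")") % tuple(args)

print(render_call(r"\text{max}", ["V_m", "V_th"]))   # -> \text{max}(V_m, V_th)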
@@ -65,6 +65,6 @@ def print_simple_expression(self, node: ASTSimpleExpression) -> str:
return self._variable_printer.print_variable(node.get_variable())

if node.is_function_call():
return self.print_function_call(node.get_function_call())
return self._function_call_printer.print(node.get_function_call())

raise Exception("Unknown node type")
3 changes: 0 additions & 3 deletions pynestml/codegeneration/printers/nestml_printer.py
@@ -479,9 +479,6 @@ def print_simple_expression(self, node: ASTSimpleExpression) -> str:
return str(node.numeric_literal)

if node.is_variable():
print("PRINTING " +node.get_variable().get_complete_name())
if node.get_variable().get_name() == "I_dend":
import pdb;pdb.set_trace()
return self.print_variable(node.get_variable())

if node.is_string():
@@ -517,22 +517,15 @@ void {{neuronName}}::update(nest::Time const & origin,const long from, const lon
{% if propagators_are_state_dependent %}
// the propagators are state dependent; update them!
recompute_internal_variables();

{% endif %}
for ( long lag = from ; lag < to ; ++lag )
{
for (long i = 0; i < NUM_SPIKE_RECEPTORS; ++i)
{
get_spike_inputs_grid_sum_()[i] = get_spike_inputs_()[i].get_value(lag);
}

// process spikes from buffers
// process inputs from continuous-type input ports
{%- for inputPort in neuron.get_continuous_input_ports() %}
B_.{{ inputPort.name }}_grid_sum_ = get_{{ inputPort.name }}().get_value(lag);
{% endfor %}
{%- filter indent(4) %}
{%- include "directives/ApplySpikesFromBuffers.jinja2" %}

{% endfilter %}
{% endfor %}
{%- if has_delay_variables %}
// delay variables

@@ -554,16 +547,25 @@ void {{neuronName}}::update(nest::Time const & origin,const long from, const lon
{%- endfor %}
{%- endfilter %}
{%- endif %}

{%- if analytic_state_variables_from_convolutions | length > 0 %}
{%- set analytic_state_variables_ = analytic_state_variables_from_convolutions %}
// integrate variables related to convolutions

{%- filter indent(4) %}
// integrate variables related to convolutions
{% filter indent(4) %}
{%- include "directives/AnalyticIntegrationStep_begin.jinja2" %}
{%- include "directives/AnalyticIntegrationStep_end.jinja2" %}
{%- endfilter %}
{% endif %}
{%- endif %}

// process spikes from buffers

for (long i = 0; i < NUM_SPIKE_RECEPTORS; ++i)
{
get_spike_inputs_grid_sum_()[i] = get_spike_inputs_()[i].get_value(lag);
}
{%- filter indent(4) %}
{%- include "directives/ApplySpikesFromBuffers.jinja2" %}
{% endfilter %}

// voltage logging
B_.logger_.record_data(origin.get_steps() + lag);
@@ -110,6 +110,7 @@ public:
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in parameter.get_decorators() %}
{%- if isHomogeneous %}
{%- with variable = utils.get_parameter_variable_by_name(astnode, parameter.get_symbol_name()) %}
{%- set variable_symbol = synapse.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) %}
{%- include "directives/CommonPropertiesDictionaryMemberInitialization.jinja2" %}
{%- endwith %}
{%- endif %}
@@ -132,9 +133,9 @@ public:
{%- if namespaceName == '' %}
{{ raise('nest::names decorator is required for parameter "%s" when used in a common properties class' % printer.print(utils.get_parameter_variable_by_name(astnode, parameter.get_symbol_name()))) }}
{%- endif %}
{%- with variable = parameter %}
{%- include "directives/CommonPropertiesDictionaryWriter.jinja2" %}
{%- endwith %}
{%- set variable_symbol = parameter %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives/CommonPropertiesDictionaryWriter.jinja2" %}
{%- endif %}
{%- endfor %}
{%- endfilter %}
@@ -156,9 +157,9 @@ public:
{%- if (namespaceName == '') %}
{{ raise('nest::names decorator is required for parameter "%s" when used in a common properties class' % printer.print(utils.get_parameter_variable_by_name(astnode, parameter.get_symbol_name()))) }}
{%- endif %}
{%- with variable = parameter %}
{%- include "directives/CommonPropertiesDictionaryReader.jinja2" %}
{%- endwith %}
{%- set variable_symbol = parameter %}
{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives/CommonPropertiesDictionaryReader.jinja2" %}
{%- endif %}
{%- endfor %}
{%- endfilter %}
@@ -595,7 +596,6 @@ public:
return tid;
};

// synapse STDP depressing/facilitation dynamics
const double __t_spike = e.get_stamp().get_ms();
#ifdef DEBUG
std::cout << "{{synapseName}}::send(): handling pre spike at t = " << __t_spike << std::endl;
@@ -681,9 +681,9 @@ public:
{%- filter indent(6, True) %}
{%- if post_ports is defined %}
{%- for post_port in spiking_post_ports %}
/**
* NESTML generated onReceive code block for postsynaptic port "{{post_port}}" begins here!
**/
/**
* NESTML generated onReceive code block for postsynaptic port "{{post_port}}" begins here!
**/

{%- set dynamics = synapse.get_on_receive_block(post_port) %}
{%- with ast = dynamics.get_block() %}
@@ -715,11 +715,6 @@ public:

const double _tr_t = __t_spike - __dendritic_delay;

#ifdef DEBUG
std::cout << "\tDepressing, old w = " << S_.w << "\n";
#endif
//std::cout << "r2 = " << get_tr_r2() << std::endl;

{%- filter indent(4, True) %}
{%- for pre_port in pre_ports %}
/**
@@ -733,10 +728,6 @@ public:
{%- endfor %}
{%- endfilter %}

#ifdef DEBUG
std::cout <<"\t-> new w = " << S_.w << std::endl;
#endif

/**
* update all convolutions with pre spikes
**/
@@ -761,10 +752,8 @@ public:
/**
* NESTML generated onReceive code block for postsynaptic port "{{post_port}}" begins here!
**/
#ifdef DEBUG
std::cout << "\tFacilitating from c = " << S_.c << " (using trace = " << S_.pre_tr << ")";
#endif
{%- set dynamics = synapse.get_on_receive_block(post_port) %}

{% set dynamics = synapse.get_on_receive_block(post_port) %}
{%- with ast = dynamics.get_block() %}
{%- include "directives/Block.jinja2" %}
{%- endwith %}
@@ -3,16 +3,16 @@
@param variable VariableSymbol Variable for which the initialization should be done
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif -%}
{%- if variable.has_declaring_expression() and not variable.is_kernel() %}
{%- if variable.has_vector_parameter() %}
this->{{ printer.print(utils.get_state_variable_by_name(astnode, variable.get_symbol_name())) }}.resize(P_.{{ variable.get_vector_parameter() }}, {{ printer.print_expression(variable.get_declaring_expression()) }}); // as {{ variable.get_type_symbol().print_symbol() }}
{%- if variable_symbol.has_declaring_expression() and not variable_symbol.is_kernel() %}
{%- if variable_symbol.has_vector_parameter() %}
this->{{ printer_no_origin.print(variable) }}.resize(P_.{{ variable_symbol.get_vector_parameter() }}, {{ printer.print_expression(variable_symbol.get_declaring_expression()) }}); // as {{ variable_symbol.get_type_symbol().print_symbol() }}
{%- else %}
this->{{ printer.print(utils.get_state_variable_by_name(astnode, variable.get_symbol_name())) }} = {{ printer.print_expression(variable.get_declaring_expression()) }}; // as {{ variable.get_type_symbol().print_symbol() }}
this->{{ printer_no_origin.print(variable) }} = {{ printer.print(variable_symbol.get_declaring_expression()) }}; // as {{ variable_symbol.get_type_symbol().print_symbol() }}
{%- endif %}
{%- else %}
{%- if variable.has_vector_parameter() %}
this->{{ printer.print(utils.get_state_variable_by_name(astnode, variable.get_symbol_name())) }}.resize(0); // as {{ variable.get_type_symbol().print_symbol() }}
{%- if variable_symbol.has_vector_parameter() %}
this->{{ printer_no_origin.print(variable) }}.resize(0); // as {{ variable_symbol.get_type_symbol().print_symbol() }}
{%- else %}
this->{{ printer.print(utils.get_state_variable_by_name(astnode, variable.get_symbol_name())) }} = 0; // as {{ variable.get_type_symbol().print_symbol() }}
this->{{ printer_no_origin.print(variable) }} = 0; // as {{ variable_symbol.get_type_symbol().print_symbol() }}
{%- endif %}
{%- endif %}
@@ -0,0 +1,9 @@
{#
In the general case, creates an updateValue<> call for the given variable
@param variable VariableSymbol Variable for which the initialization should be done
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if variable_symbol.has_vector_parameter() %}
{{ raise('Vector parameters not supported in common properties dictionary.') }}
{%- endif %}
updateValue< {{ declarations.print_variable_type(variable_symbol) }} >(d, names::{{namespaceName}}, this->{{ printer_no_origin.print(variable) }} );
@@ -3,7 +3,7 @@
@param variable VariableSymbol Variable for which the initialization should be done
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if variable.has_vector_parameter() %}
{%- if variable_symbol.has_vector_parameter() %}
{{ raise('Vector parameters not supported in common properties dictionary.') }}
{%- endif %}
def< {{ declarations.print_variable_type(variable) }} >(d, names::{{ namespaceName }}, this->{{ printer.print(utils.get_state_variable_by_name(astnode, variable.get_symbol_name())) }} );
def< {{ declarations.print_variable_type(variable_symbol) }} >(d, names::{{ namespaceName }}, this->{{ printer_no_origin.print(variable) }} );
@@ -4,6 +4,7 @@
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if utils.is_integrate(ast) %}

{%- if ast.get_args() | length > 0 %}
{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_except_convolutions, ast.get_args()) %}
{%- else %}
@@ -13,10 +14,26 @@
{%- include "directives/AnalyticIntegrationStep_begin.jinja2" %}

{%- if uses_numeric_solver %}

// solver step should update state of convolutions internally, but not change ode_state[] pertaining to convolutions; convolution integration should be independent of integrate_odes() calls
// buffer the old values
{%- for variable_name in analytic_state_variables_from_convolutions %}
{%- set update_expr = update_expressions[variable_name] %}
{%- set variable_symbol = variable_symbols[variable_name] %}
const double {{ variable_name }}__orig = {{ printer.print(utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name())) }};
{%- endfor %}

{%- include "directives/GSLIntegrationStep.jinja2" %}

// restore the old values for convolutions
{%- for variable_name in analytic_state_variables_from_convolutions %}
{%- set variable_symbol = variable_symbols[variable_name] %}
{{ printer.print(utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name())) }} = {{ variable_name }}__orig;
{%- endfor %}

{%- endif %}

{%- include "directives/AnalyticIntegrationStep_end.jinja2" %}
{%- else %}
{{printer.print(ast)}};
{{ printer.print(ast) }};
{%- endif %}
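
The buffering around the GSL step above keeps convolution state out of the numeric solver's hands: old values are saved, the solver runs, and the convolution entries are restored before being advanced analytically. A runnable Python sketch of the generated logic (illustrative names; the generated code is C++):

# illustrative stand-ins for the generated C++ state
ode_state = {"V_m": -65.0, "g_ex__X__spikes": 0.42}
convolution_variables = ["g_ex__X__spikes"]

saved = {name: ode_state[name] for name in convolution_variables}   # buffer old values

# ... the numeric (GSL) solver step runs here and may overwrite convolution entries ...
ode_state["g_ex__X__spikes"] = 999.0   # side effect of the solver (illustrative)

for name in convolution_variables:
    ode_state[name] = saved[name]      # restore; convolutions are advanced analytically afterwards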
8 changes: 7 additions & 1 deletion pynestml/utils/ast_utils.py
@@ -1890,10 +1890,12 @@ def generate_kernel_buffers_(cls, neuron: ASTNeuron, equations_block: Union[ASTE
return kernel_buffers

@classmethod
def replace_convolution_aliasing_inlines(cls, neuron: ASTNeuron) -> None:
def replace_convolution_aliasing_inlines(cls, neuron: ASTNeuron) -> Dict[str, str]:
"""
Replace all occurrences of kernel names (e.g. ``I_dend`` and ``I_dend'`` for a definition involving a second-order kernel ``inline kernel I_dend = convolve(kern_name, spike_buf)``) with the ODE-toolbox generated variable ``kern_name__X__spike_buf``.
"""
aliases: Dict[str, str] = {}

def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
if isinstance(_expr, ASTSimpleExpression) and _expr.is_variable():
var = _expr.get_variable()
@@ -1902,12 +1904,14 @@ def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
differential_order=0)
ast_variable.set_source_position(var.get_source_position())
_expr.set_variable(ast_variable)
aliases[var.get_name()] = replace_with_var_name + '__d' * var.get_differential_order()

elif isinstance(_expr, ASTVariable):
var = _expr
if var.get_name() == replace_var_name:
# record the alias before renaming, so the key is the original variable name
aliases[var.get_name()] = replace_with_var_name + '__d' * var.get_differential_order()
var.set_name(replace_with_var_name + '__d' * var.get_differential_order())
var.set_differential_order(0)

for equation_block in neuron.get_equations_blocks():
for decl in equation_block.get_declarations():
@@ -1917,6 +1921,8 @@ def replace_var(_expr, replace_var_name: str, replace_with_var_name: str):
replace_with_var_name = decl.get_expression().get_variable().get_name()
neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var(x, decl.get_variable_name(), replace_with_var_name)))

return aliases

@classmethod
def replace_variable_names_in_expressions(cls, neuron: ASTNeuron, solver_dicts: List[dict]) -> None:
"""
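The mapping returned by ``replace_convolution_aliasing_inlines()`` lets callers translate user-facing inline names into the ODE-toolbox generated state variables. A hypothetical example of its contents:

# assuming `neuron` is a parsed ASTNeuron (hypothetical driver code):
aliases = ASTUtils.replace_convolution_aliasing_inlines(neuron)
# for a model containing: inline I_dend pA = convolve(kern_name, spike_buf)
# the mapping would be e.g.: {"I_dend": "kern_name__X__spike_buf"}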