Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve speed of dump_connections() in layer_impl.h #3160

Merged
merged 14 commits into from
Jun 24, 2024
Merged
75 changes: 40 additions & 35 deletions nestkernel/layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,48 +309,53 @@ Layer< D >::dump_connections( std::ostream& out,
DictionaryDatum conn_filter( new Dictionary );
def( conn_filter, names::synapse_model, syn_model );
def( conn_filter, names::target, NodeCollectionDatum( target_layer->get_node_collection() ) );
def( conn_filter, names::source, NodeCollectionDatum( node_collection ) );

// Avoid setting up new array for each iteration of the loop
std::vector< size_t > source_array( 1 );
ArrayDatum connectome = kernel().connection_manager.get_connections( conn_filter );

for ( typename std::vector< std::pair< Position< D >, size_t > >::iterator src_iter = src_vec->begin();
src_iter != src_vec->end();
++src_iter )
// Get variables for loop
size_t previous_source_node_id = getValue< ConnectionDatum >( connectome.get( 0 ) ).get_source_node_id();
Position< D > source_pos = src_vec->begin()->first;

// Print information about all local connections for current source
for ( size_t i = 0; i < connectome.size(); ++i )
{
ConnectionDatum con_id = getValue< ConnectionDatum >( connectome.get( i ) );
const size_t source_node_id = con_id.get_source_node_id();

const size_t source_node_id = src_iter->second;
const Position< D > source_pos = src_iter->first;
// Search source_pos for source node only if it is a different node
if ( source_node_id != previous_source_node_id )
{
source_pos = src_vec->begin()->first;

source_array[ 0 ] = source_node_id;
def( conn_filter, names::source, NodeCollectionDatum( NodeCollection::create( source_array ) ) );
ArrayDatum connectome = kernel().connection_manager.get_connections( conn_filter );
for ( typename std::vector< std::pair< Position< D >, size_t > >::iterator src_iter = src_vec->begin();
( src_iter != src_vec->end() ) && ( source_node_id != src_iter->second );
++src_iter, source_pos = src_iter->first )
;

// Print information about all local connections for current source
for ( size_t i = 0; i < connectome.size(); ++i )
{
ConnectionDatum con_id = getValue< ConnectionDatum >( connectome.get( i ) );
DictionaryDatum result_dict = kernel().connection_manager.get_synapse_status( con_id.get_source_node_id(),
con_id.get_target_node_id(),
con_id.get_target_thread(),
con_id.get_synapse_model_id(),
con_id.get_port() );

long target_node_id = getValue< long >( result_dict, names::target );
double weight = getValue< double >( result_dict, names::weight );
double delay = getValue< double >( result_dict, names::delay );

// Print source, target, weight, delay, rports
out << source_node_id << ' ' << target_node_id << ' ' << weight << ' ' << delay;

Layer< D >* tgt_layer = dynamic_cast< Layer< D >* >( target_layer.get() );

out << ' ';
const long tnode_lid = tgt_layer->node_collection_->get_lid( target_node_id );
assert( tnode_lid >= 0 );
tgt_layer->compute_displacement( source_pos, tnode_lid ).print( out );
out << '\n';
previous_source_node_id = source_node_id;
}
}

DictionaryDatum result_dict = kernel().connection_manager.get_synapse_status( source_node_id,
con_id.get_target_node_id(),
con_id.get_target_thread(),
con_id.get_synapse_model_id(),
con_id.get_port() );

long target_node_id = getValue< long >( result_dict, names::target );
double weight = getValue< double >( result_dict, names::weight );
double delay = getValue< double >( result_dict, names::delay );

// Print source, target, weight, delay, rports
out << source_node_id << ' ' << target_node_id << ' ' << weight << ' ' << delay;

Layer< D >* tgt_layer = dynamic_cast< Layer< D >* >( target_layer.get() );

out << ' ';
const long tnode_lid = tgt_layer->node_collection_->get_lid( target_node_id );
assert( tnode_lid >= 0 );
tgt_layer->compute_displacement( source_pos, tnode_lid ).print( out );
out << '\n';
}
heplesser marked this conversation as resolved.
Show resolved Hide resolved

template < int D >
Expand Down
Loading