Skip to content

Commit

Permalink
Move layer and network-module instantiation to graph-extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
bauerfe committed Oct 15, 2024
1 parent 8c020b2 commit 5082450
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 45 deletions.
30 changes: 7 additions & 23 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@

from .chip_factory import ChipFactory
from .dvs_layer import DVSLayer
from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper
from .dynapcnnnetwork_module import DynapcnnNetworkModule
from .io import disable_timestamps, enable_timestamps, open_device, reset_timestamps
from .nir_graph_extractor import GraphExtractor
from .sinabs_edges_handler import collect_dynapcnn_layer_info
from .utils import (
DEFAULT_IGNORED_LAYER_TYPES,
parse_device_id,
Expand Down Expand Up @@ -83,28 +80,11 @@ def __init__(
# Remove nodes of ignored classes (including merge nodes)
self._graph_extractor.remove_nodes_by_class(DEFAULT_IGNORED_LAYER_TYPES)

# create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
self._dcnnl_map = collect_dynapcnn_layer_info(
self._graph_extractor.indx_2_module_map,
self._graph_extractor.edges,
self._graph_extractor.nodes_io_shapes,
self._graph_extractor.entry_nodes,
)

# build `DynapcnnLayer` instances from mapper.
dynapcnn_layers, destination_map, entry_points = (
construct_dynapcnnlayers_from_mapper(
dcnnl_map=self._dcnnl_map,
discretize=discretize,
rescale_fn=weight_rescaling_fn,
)
)

# Module to execute forward pass through network
self._dynapcnn_module = DynapcnnNetworkModule(
dynapcnn_layers, destination_map, entry_points
self._dynapcnn_module = self._graph_extractor.get_dynapcnn_network_module(
discretize=discretize, weight_rescaling_fn=weight_rescaling_fn
)
self.dynapcnn_module.setup_dynapcnnlayer_graph(index_layers_topologically=True)
self._dynapcnn_module.setup_dynapcnnlayer_graph(index_layers_topologically=True)

####################################################### Public Methods #######################################################

Expand All @@ -116,6 +96,10 @@ def dynapcnn_layers(self):
def dynapcnn_module(self):
return self._dynapcnn_module

@property
def layer_destination_map(self):
return self._dynapcnn_module.destination_map

@property
def chip_layers_ordering(self):
return self._chip_layers_ordering
Expand Down
108 changes: 86 additions & 22 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# author : Willian Soares Girao
# contact : [email protected]

from typing import Dict, List, Set, Tuple, Type
from typing import Callable, Dict, List, Optional, Set, Tuple, Type

import nirtorch
import torch
Expand All @@ -13,7 +13,10 @@
LAYER_TYPES_WITH_MULTIPLE_INPUTS,
LAYER_TYPES_WITH_MULTIPLE_OUTPUTS,
)
from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper
from .dynapcnnnetwork_module import DynapcnnNetworkModule
from .exceptions import InvalidGraphStructure
from .sinabs_edges_handler import collect_dynapcnn_layer_info
from .utils import Edge, topological_sorting


Expand Down Expand Up @@ -51,11 +54,12 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor):
spiking_model, dummy_input, model_name=None
).ignore_tensors()

# converts the NIR representation into a list of edges with nodes represented as integers.
self._edges, self._name_2_indx_map, self._entry_nodes = (
self._get_edges_from_nir(nir_graph)
)

# Map node names to indices
self._name_2_indx_map = self._get_name_2_indx_map(nir_graph)
# Extract edges list from graph
self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map)
# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges)
# Store the associated `nn.Module` (layer) of each node.
self._indx_2_module_map = self._get_named_modules(spiking_model)

Expand Down Expand Up @@ -91,6 +95,47 @@ def sorted_nodes(self) -> List[int]:
def indx_2_module_map(self) -> Dict[int, nn.Module]:
return {n: module for n, module in self._indx_2_module_map.items()}

def get_dynapcnn_network_module(
self, discretize: bool = False, weight_rescaling_fn: Optional[Callable] = None
) -> DynapcnnNetworkModule:
""" Create DynapcnnNetworkModule based on stored graph representation
This includes construction of the DynapcnnLayer instances
Parameters:
-----------
- discretize (bool): If `True`, discretize the parameters and thresholds. This is needed for uploading
weights to dynapcnn. Set to `False` only for testing purposes.
- weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to
the same convolutional layer are combined/re-scaled before applying them.
Returns
-------
- The DynapcnnNetworkModule based on graph representation of this `GraphExtractor`
"""
# create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
dcnnl_map = collect_dynapcnn_layer_info(
indx_2_module_map = self.indx_2_module_map,
edges = self.edges,
nodes_io_shapes=self.nodes_io_shapes,
entry_nodes=self.entry_nodes,
)

# build `DynapcnnLayer` instances from mapper.
dynapcnn_layers, destination_map, entry_points = (
construct_dynapcnnlayers_from_mapper(
dcnnl_map=dcnnl_map,
discretize=discretize,
rescale_fn=weight_rescaling_fn,
)
)

# Instantiate the DynapcnnNetworkModule
return DynapcnnNetworkModule(
dynapcnn_layers, destination_map, entry_points
)

def remove_nodes_by_class(self, node_classes: Tuple[Type]):
"""Remove nodes of given classes from graph in place.
Expand Down Expand Up @@ -174,40 +219,59 @@ def verify_graph_integrity(self):

####################################################### Pivate Methods #######################################################

def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.Graph) -> Dict[str, int]:
"""Assign unique index to each node and return mapper from name to index.
Parameters
----------
- nir_graph (nirtorch.graph.Graph): a NIR graph representation of `spiking_model`.
Returns
----------
- name_2_indx_map (dict): `key` is the original variable name for a layer in
`spiking_model` and `value is an integer representing the layer in a standard format.
"""
return {
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
}

def _get_edges_from_nir(
self, nir_graph: nirtorch.graph.Graph
) -> Tuple[List[Edge], Dict[str, int], List[int]]:
"""Standardize the representation of `nirtorch.graph.Graph` into a list of edges (`Edge`) where each node in `nir_graph` is represented by an interger (with the source node starting as `0`).
self, nir_graph: nirtorch.graph.Graph, name_2_indx_map: Dict[str, int]
) -> Set[Edge]:
"""Standardize the representation of `nirtorch.graph.Graph` into a list of edges,
representing nodes by their indices.
Parameters
----------
- nir_graph (nirtorch.graph.Graph): a NIR graph representation of `spiking_model`.
- name_2_indx_map (dict): Map from node names to unique indices.
Returns
----------
- edges (set): tuples describing the connections between layers in `spiking_model`.
- name_2_indx_map (dict): `key` is the original variable name for a layer in `spiking_model` and `value is an integer representing the layer in a standard format.
- entry_nodes (set): IDs of nodes acting as entry points for the network (i.e., receiving external input).
"""
# TODO maybe make sure the input node from nir always gets assined `0`.

# Assign a unique index to each node
name_2_indx_map = {
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
}

# Extract edges for each node
edges = {
return {
(name_2_indx_map[src.name], name_2_indx_map[tgt.name])
for src in nir_graph.node_list
for tgt in src.outgoing_nodes
}

# find entry nodes of the graph.
all_sources, all_targets = zip(*edges)
entry_nodes = set(all_sources) - set(all_targets)
def _get_entry_nodes(self, edges: Set[Edge]) -> Set[Edge]:
"""Find nodes that act as entry points to the graph
Parameters
----------
- edges (set): tuples describing the connections between layers in `spiking_model`.
return edges, name_2_indx_map, entry_nodes
Returns
----------
- entry_nodes (set): IDs of nodes acting as entry points for the network
(i.e., receiving external input).
"""
all_sources, all_targets = zip(*edges)
return set(all_sources) - set(all_targets)

def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]:
"""Find for each node in the graph what its associated layer in `model` is.
Expand Down

0 comments on commit 5082450

Please sign in to comment.