-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move layer and network-module instantiation to graph-extractor
- Loading branch information
Showing
2 changed files
with
93 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# author : Willian Soares Girao | ||
# contact : [email protected] | ||
|
||
from typing import Dict, List, Set, Tuple, Type | ||
from typing import Callable, Dict, List, Optional, Set, Tuple, Type | ||
|
||
import nirtorch | ||
import torch | ||
|
@@ -13,7 +13,10 @@ | |
LAYER_TYPES_WITH_MULTIPLE_INPUTS, | ||
LAYER_TYPES_WITH_MULTIPLE_OUTPUTS, | ||
) | ||
from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper | ||
from .dynapcnnnetwork_module import DynapcnnNetworkModule | ||
from .exceptions import InvalidGraphStructure | ||
from .sinabs_edges_handler import collect_dynapcnn_layer_info | ||
from .utils import Edge, topological_sorting | ||
|
||
|
||
|
@@ -51,11 +54,12 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor): | |
spiking_model, dummy_input, model_name=None | ||
).ignore_tensors() | ||
|
||
# converts the NIR representation into a list of edges with nodes represented as integers. | ||
self._edges, self._name_2_indx_map, self._entry_nodes = ( | ||
self._get_edges_from_nir(nir_graph) | ||
) | ||
|
||
# Map node names to indices | ||
self._name_2_indx_map = self._get_name_2_indx_map(nir_graph) | ||
# Extract edges list from graph | ||
self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map) | ||
# Determine entry points to graph | ||
self._entry_nodes = self._get_entry_nodes(self._edges) | ||
# Store the associated `nn.Module` (layer) of each node. | ||
self._indx_2_module_map = self._get_named_modules(spiking_model) | ||
|
||
|
@@ -91,6 +95,47 @@ def sorted_nodes(self) -> List[int]: | |
def indx_2_module_map(self) -> Dict[int, nn.Module]: | ||
return {n: module for n, module in self._indx_2_module_map.items()} | ||
|
||
def get_dynapcnn_network_module( | ||
self, discretize: bool = False, weight_rescaling_fn: Optional[Callable] = None | ||
) -> DynapcnnNetworkModule: | ||
""" Create DynapcnnNetworkModule based on stored graph representation | ||
This includes construction of the DynapcnnLayer instances | ||
Parameters: | ||
----------- | ||
- discretize (bool): If `True`, discretize the parameters and thresholds. This is needed for uploading | ||
weights to dynapcnn. Set to `False` only for testing purposes. | ||
- weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to | ||
the same convolutional layer are combined/re-scaled before applying them. | ||
Returns | ||
------- | ||
- The DynapcnnNetworkModule based on graph representation of this `GraphExtractor` | ||
""" | ||
# create a dict holding the data necessary to instantiate a `DynapcnnLayer`. | ||
dcnnl_map = collect_dynapcnn_layer_info( | ||
indx_2_module_map = self.indx_2_module_map, | ||
edges = self.edges, | ||
nodes_io_shapes=self.nodes_io_shapes, | ||
entry_nodes=self.entry_nodes, | ||
) | ||
|
||
# build `DynapcnnLayer` instances from mapper. | ||
dynapcnn_layers, destination_map, entry_points = ( | ||
construct_dynapcnnlayers_from_mapper( | ||
dcnnl_map=dcnnl_map, | ||
discretize=discretize, | ||
rescale_fn=weight_rescaling_fn, | ||
) | ||
) | ||
|
||
# Instantiate the DynapcnnNetworkModule | ||
return DynapcnnNetworkModule( | ||
dynapcnn_layers, destination_map, entry_points | ||
) | ||
|
||
def remove_nodes_by_class(self, node_classes: Tuple[Type]): | ||
"""Remove nodes of given classes from graph in place. | ||
|
@@ -174,40 +219,59 @@ def verify_graph_integrity(self): | |
|
||
####################################################### Pivate Methods ####################################################### | ||
|
||
def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.Graph) -> Dict[str, int]: | ||
"""Assign unique index to each node and return mapper from name to index. | ||
Parameters | ||
---------- | ||
- nir_graph (nirtorch.graph.Graph): a NIR graph representation of `spiking_model`. | ||
Returns | ||
---------- | ||
- name_2_indx_map (dict): `key` is the original variable name for a layer in | ||
`spiking_model` and `value is an integer representing the layer in a standard format. | ||
""" | ||
return { | ||
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list) | ||
} | ||
|
||
def _get_edges_from_nir( | ||
self, nir_graph: nirtorch.graph.Graph | ||
) -> Tuple[List[Edge], Dict[str, int], List[int]]: | ||
"""Standardize the representation of `nirtorch.graph.Graph` into a list of edges (`Edge`) where each node in `nir_graph` is represented by an interger (with the source node starting as `0`). | ||
self, nir_graph: nirtorch.graph.Graph, name_2_indx_map: Dict[str, int] | ||
) -> Set[Edge]: | ||
"""Standardize the representation of `nirtorch.graph.Graph` into a list of edges, | ||
representing nodes by their indices. | ||
Parameters | ||
---------- | ||
- nir_graph (nirtorch.graph.Graph): a NIR graph representation of `spiking_model`. | ||
- name_2_indx_map (dict): Map from node names to unique indices. | ||
Returns | ||
---------- | ||
- edges (set): tuples describing the connections between layers in `spiking_model`. | ||
- name_2_indx_map (dict): `key` is the original variable name for a layer in `spiking_model` and `value is an integer representing the layer in a standard format. | ||
- entry_nodes (set): IDs of nodes acting as entry points for the network (i.e., receiving external input). | ||
""" | ||
# TODO maybe make sure the input node from nir always gets assined `0`. | ||
|
||
# Assign a unique index to each node | ||
name_2_indx_map = { | ||
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list) | ||
} | ||
|
||
# Extract edges for each node | ||
edges = { | ||
return { | ||
(name_2_indx_map[src.name], name_2_indx_map[tgt.name]) | ||
for src in nir_graph.node_list | ||
for tgt in src.outgoing_nodes | ||
} | ||
|
||
# find entry nodes of the graph. | ||
all_sources, all_targets = zip(*edges) | ||
entry_nodes = set(all_sources) - set(all_targets) | ||
def _get_entry_nodes(self, edges: Set[Edge]) -> Set[Edge]: | ||
"""Find nodes that act as entry points to the graph | ||
Parameters | ||
---------- | ||
- edges (set): tuples describing the connections between layers in `spiking_model`. | ||
return edges, name_2_indx_map, entry_nodes | ||
Returns | ||
---------- | ||
- entry_nodes (set): IDs of nodes acting as entry points for the network | ||
(i.e., receiving external input). | ||
""" | ||
all_sources, all_targets = zip(*edges) | ||
return set(all_sources) - set(all_targets) | ||
|
||
def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]: | ||
"""Find for each node in the graph what its associated layer in `model` is. | ||
|