Skip to content

Commit

Permalink
WIP - DVS node not given
Browse files Browse the repository at this point in the history
constructor of GraphExtractor executing without errors
  • Loading branch information
Willian-Girao committed Oct 29, 2024
1 parent beade18 commit 89200d9
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 23 deletions.
8 changes: 4 additions & 4 deletions sinabs/backend/dynapcnn/connectivity_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Pooling = (sl.SumPool2d, nn.AvgPool2d)
Weight = (nn.Conv2d, nn.Linear)
Neuron = (sl.IAFSqueeze,)
DVS = (DVSLayer, Crop2d, FlipDims)
DVS = (DVSLayer, )

# @TODO - need to list other edge cases involving DVS layer (for now only dvs-weight and dvs-pooling).
VALID_SINABS_EDGE_TYPES_ABSTRACT = {
Expand All @@ -29,9 +29,9 @@
# Pooling can be followed by weight layer of next core
(Pooling, Weight): "pooling-weight",
# Dvs can be followed by weight layer of next core
(DVSLayer, Weight): "dvs-weight",
(DVS, Weight): "dvs-weight",
# Dvs can be followed by pooling layers
(DVSLayer, Pooling): "dvs-pooling",
(DVS, Pooling): "dvs-pooling",
}

# Unpack dict
Expand All @@ -46,4 +46,4 @@
LAYER_TYPES_WITH_MULTIPLE_INPUTS = Union[sl.Merge]

# Neuron and pooling layers can have their output sent to multiple cores
LAYER_TYPES_WITH_MULTIPLE_OUTPUTS = Union[(*Neuron, *Pooling, DVSLayer)]
LAYER_TYPES_WITH_MULTIPLE_OUTPUTS = Union[(*Neuron, *Pooling, *DVS)]
25 changes: 12 additions & 13 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,18 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu
# Store the associated `nn.Module` (layer) of each node.
self._indx_2_module_map = self._get_named_modules(spiking_model)

# True if `dvs_input == True` and `spiking_model` does not start with DVS layer.
# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges)

# If DVS camera is wanted but `spiking_model` does not start with DVS layer.
if self._need_dvs_node(spiking_model, dvs_input):
# input shape for `DVSLayer` instance that will be the module of the node 'dvs'.
_, _, height, width = dummy_input.shape
self._add_dvs_node(dvs_input_shape=(height, width))
# Consolidates the edges associated with a DVSLayer instance (ie., fix NIR edges extraction when DVS is a node in the graph).
fix_dvs_module_edges(edges=self._edges, indx_2_module_map=self._indx_2_module_map)

# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges)

print('----------------------------------------')
print(self._edges)
for key, val in self._name_2_indx_map.items():
print(key, val)
print('----------------------------------------')

# Consolidates the edges associated with a DVSLayer instance.
fix_dvs_module_edges(edges=self._edges, indx_2_module_map=self._indx_2_module_map)
self._entry_nodes

# Verify that graph is compatible
self.verify_graph_integrity()
Expand Down Expand Up @@ -247,7 +242,9 @@ def verify_graph_integrity(self):

def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
""" In-place modification of `self._name_2_indx_map`, `self._indx_2_module_map`, and `self._edges` to accomodate the
creation of an extra node in the graph representing the DVS camera of the chip.
creation of an extra node in the graph representing the DVS camera of the chip. The DVSLayer node will point to every
other node that is up to this point an entry node of the original graph, so `self._entry_nodes` is modified in-place
to have only one entry: the index of the DVS node.
Parameters
----------
Expand All @@ -263,6 +260,8 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
self._indx_2_module_map[self._name_2_indx_map['dvs']] = DVSLayer(input_shape=dvs_input_shape)
# set DVS node as input to each entry node of the graph.
self._edges.update({(self._name_2_indx_map['dvs'], entry_node) for entry_node in self._entry_nodes})
# DVSLayer node becomes the only entrypoint of the graph.
self._entry_nodes = {self._name_2_indx_map['dvs']}

def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:
""" Returns whether or not a node will need to be added to represent a `DVSLayer` instance. A new node will have
Expand Down
23 changes: 17 additions & 6 deletions sinabs/backend/dynapcnn/sinabs_edges_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from torch import Size, nn

from .connectivity_specs import VALID_SINABS_EDGE_TYPES, DVS
from .connectivity_specs import VALID_SINABS_EDGE_TYPES
from .exceptions import (
InvalidEdge,
InvalidGraphStructure,
Expand Down Expand Up @@ -46,26 +46,32 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul
feeds that in the sequence `DVS.pool -> DVS.crop -> DVS.flip`, so we remove edges involving these nodes that are internaly
implementend in the DVSLayer instance from the graph and point the node of DVSLayer directly to the layer/module it is suppoed
to forward its data to.
Modifies `indx_2_module_map` in-place to remove the nodes (Crop2d, FlipDims and DVSLayer's pooling) defined within the DVSLayer
instance since those are not independent nodes of the final graph.
Currently, this is also removing a self-recurrent node with edge `(FlipDims, FlipDims)` that is
created when forwarding via DVSLayer.
The 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by
the DVSLayer. For now this function is fixing these edges to have them representing the information flow through
this layer as **it should be** but the graph tracing of NIR should be looked into to solve the root problem.
Parameters
----------
- edges (set): tuples describing the connections between layers in `spiking_model`.
- indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`).
"""
# TODO - the 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by
# the DVSLayer. For now this function is fixing these edges to have them representing the information flow through
# this layer as **it should be** but the graph tracing of NIR should be looked into to solve the root problem.

# spot nodes (ie, modules) used in a DVSLayer instance's forward pass (including the DVSLayer node itself).
dvslayer_nodes = {
index: module for index, module in indx_2_module_map.items()
if any(isinstance(module, dvs_node) for dvs_node in DVS)
if any(isinstance(module, dvs_node) for dvs_node in (DVSLayer, Crop2d, FlipDims))
}

if len(dvslayer_nodes) == 1:
# No module within the DVSLayer instance appears as an independent node - nothing to do here.
return

# TODO - a `SumPool2d` is also a node that's used inside a DVSLayer instance. In what follows we try to find it
# by looking for pooling nodes that appear in a (pool, crop) edge - the assumption being that if the pooling is
# inputing into a crop layer than the pool is inside the DVSLayer instance. It feels like a hacky way to do it
Expand Down Expand Up @@ -94,6 +100,11 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul
if any(len(node) > 1 for node in [dvs_node, dvs_pool_node, dvs_crop_node, dvs_flip_node]):
raise ValueError(f'Internal DVS nodes should be single instances but multiple have been found: dvs_node: {len(dvs_node)} dvs_pool_node: {len(dvs_pool_node)} dvs_crop_node: {len(dvs_crop_node)} dvs_flip_node: {len(dvs_flip_node)}')

# Remove dvs_pool, dvs_crop and dvs_flip nodes from `indx_2_module_map` (these operate within the DVS, not as independent nodes of the final graph).
indx_2_module_map.pop(dvs_pool_node[-1])
indx_2_module_map.pop(dvs_crop_node[-1])
indx_2_module_map.pop(dvs_flip_node[-1])

# dvs_pool, dvs_crop and dvs_flip are internal nodes of the DVSLayer: we only want an edge from 'dvs' node to the entry points of the network.
edges.update({(dvs_node[-1], node) for node in entry_nodes})

Expand Down

0 comments on commit 89200d9

Please sign in to comment.