diff --git a/sinabs/backend/dynapcnn/connectivity_specs.py b/sinabs/backend/dynapcnn/connectivity_specs.py index 13e17911..c361f240 100644 --- a/sinabs/backend/dynapcnn/connectivity_specs.py +++ b/sinabs/backend/dynapcnn/connectivity_specs.py @@ -14,7 +14,7 @@ Pooling = (sl.SumPool2d, nn.AvgPool2d) Weight = (nn.Conv2d, nn.Linear) Neuron = (sl.IAFSqueeze,) -DVS = (DVSLayer, Crop2d, FlipDims) +DVS = (DVSLayer, ) # @TODO - need to list other edge cases involving DVS layer (for now only dvs-weight and dvs-pooling). VALID_SINABS_EDGE_TYPES_ABSTRACT = { @@ -29,9 +29,9 @@ # Pooling can be followed by weight layer of next core (Pooling, Weight): "pooling-weight", # Dvs can be followed by weight layer of next core - (DVSLayer, Weight): "dvs-weight", + (DVS, Weight): "dvs-weight", # Dvs can be followed by pooling layers - (DVSLayer, Pooling): "dvs-pooling", + (DVS, Pooling): "dvs-pooling", } # Unpack dict @@ -46,4 +46,4 @@ LAYER_TYPES_WITH_MULTIPLE_INPUTS = Union[sl.Merge] # Neuron and pooling layers can have their output sent to multiple cores -LAYER_TYPES_WITH_MULTIPLE_OUTPUTS = Union[(*Neuron, *Pooling, DVSLayer)] +LAYER_TYPES_WITH_MULTIPLE_OUTPUTS = Union[(*Neuron, *Pooling, *DVS)] diff --git a/sinabs/backend/dynapcnn/nir_graph_extractor.py b/sinabs/backend/dynapcnn/nir_graph_extractor.py index fe984855..0a80bec0 100644 --- a/sinabs/backend/dynapcnn/nir_graph_extractor.py +++ b/sinabs/backend/dynapcnn/nir_graph_extractor.py @@ -66,23 +66,18 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu # Store the associated `nn.Module` (layer) of each node. self._indx_2_module_map = self._get_named_modules(spiking_model) - # True if `dvs_input == True` and `spiking_model` does not start with DVS layer. + # Determine entry points to graph + self._entry_nodes = self._get_entry_nodes(self._edges) + + # If DVS camera is wanted but `spiking_model` does not start with DVS layer. if self._need_dvs_node(spiking_model, dvs_input): # input shape for `DVSLayer` instance that will be the module of the node 'dvs'. _, _, height, width = dummy_input.shape self._add_dvs_node(dvs_input_shape=(height, width)) + # Consolidates the edges associated with a DVSLayer instance (ie., fix NIR edges extraction when DVS is a node in the graph). + fix_dvs_module_edges(edges=self._edges, indx_2_module_map=self._indx_2_module_map) - # Determine entry points to graph - self._entry_nodes = self._get_entry_nodes(self._edges) - - print('----------------------------------------') - print(self._edges) - for key, val in self._name_2_indx_map.items(): - print(key, val) - print('----------------------------------------') - - # Consolidates the edges associated with a DVSLayer instance. - fix_dvs_module_edges(edges=self._edges, indx_2_module_map=self._indx_2_module_map) + self._entry_nodes # Verify that graph is compatible self.verify_graph_integrity() @@ -247,7 +242,9 @@ def verify_graph_integrity(self): def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None: """ In-place modification of `self._name_2_indx_map`, `self._indx_2_module_map`, and `self._edges` to accomodate the - creation of an extra node in the graph representing the DVS camera of the chip. + creation of an extra node in the graph representing the DVS camera of the chip. The DVSLayer node will point to every + other node that is up to this point an entry node of the original graph, so `self._entry_nodes` is modified in-place + to have only one entry: the index of the DVS node. Parameters ---------- @@ -263,6 +260,8 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None: self._indx_2_module_map[self._name_2_indx_map['dvs']] = DVSLayer(input_shape=dvs_input_shape) # set DVS node as input to each entry node of the graph. self._edges.update({(self._name_2_indx_map['dvs'], entry_node) for entry_node in self._entry_nodes}) + # DVSLayer node becomes the only entrypoint of the graph. + self._entry_nodes = {self._name_2_indx_map['dvs']} def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool: """ Returns whether or not a node will need to be added to represent a `DVSLayer` instance. A new node will have diff --git a/sinabs/backend/dynapcnn/sinabs_edges_handler.py b/sinabs/backend/dynapcnn/sinabs_edges_handler.py index 3b73c877..f524f784 100644 --- a/sinabs/backend/dynapcnn/sinabs_edges_handler.py +++ b/sinabs/backend/dynapcnn/sinabs_edges_handler.py @@ -10,7 +10,7 @@ from torch import Size, nn -from .connectivity_specs import VALID_SINABS_EDGE_TYPES, DVS +from .connectivity_specs import VALID_SINABS_EDGE_TYPES from .exceptions import ( InvalidEdge, InvalidGraphStructure, @@ -46,26 +46,32 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul feeds that in the sequence `DVS.pool -> DVS.crop -> DVS.flip`, so we remove edges involving these nodes that are internaly implementend in the DVSLayer instance from the graph and point the node of DVSLayer directly to the layer/module it is suppoed to forward its data to. + + Modifies `indx_2_module_map` in-place to remove the nodes (Crop2d, FlipDims and DVSLayer's pooling) defined within the DVSLayer + instance since those are not independent nodes of the final graph. Currently, this is also removing a self-recurrent node with edge `(FlipDims, FlipDims)` that is created when forwarding via DVSLayer. - The 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by - the DVSLayer. For now this function is fixing these edges to have them representing the information flow through - this layer as **it should be** but the graph tracing of NIR should be looked into to solve the root problem. - Parameters ---------- - edges (set): tuples describing the connections between layers in `spiking_model`. - indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`). """ + # TODO - the 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by + # the DVSLayer. For now this function is fixing these edges to have them representing the information flow through + # this layer as **it should be** but the graph tracing of NIR should be looked into to solve the root problem. # spot nodes (ie, modules) used in a DVSLayer instance's forward pass (including the DVSLayer node itself). dvslayer_nodes = { index: module for index, module in indx_2_module_map.items() - if any(isinstance(module, dvs_node) for dvs_node in DVS) + if any(isinstance(module, dvs_node) for dvs_node in (DVSLayer, Crop2d, FlipDims)) } + if len(dvslayer_nodes) == 1: + # No module within the DVSLayer instance appears as an independent node - nothing to do here. + return + # TODO - a `SumPool2d` is also a node that's used inside a DVSLayer instance. In what follows we try to find it # by looking for pooling nodes that appear in a (pool, crop) edge - the assumption being that if the pooling is # inputing into a crop layer than the pool is inside the DVSLayer instance. It feels like a hacky way to do it @@ -94,6 +100,11 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul if any(len(node) > 1 for node in [dvs_node, dvs_pool_node, dvs_crop_node, dvs_flip_node]): raise ValueError(f'Internal DVS nodes should be single instances but multiple have been found: dvs_node: {len(dvs_node)} dvs_pool_node: {len(dvs_pool_node)} dvs_crop_node: {len(dvs_crop_node)} dvs_flip_node: {len(dvs_flip_node)}') + # Remove dvs_pool, dvs_crop and dvs_flip nodes from `indx_2_module_map` (these operate within the DVS, not as independent nodes of the final graph). + indx_2_module_map.pop(dvs_pool_node[-1]) + indx_2_module_map.pop(dvs_crop_node[-1]) + indx_2_module_map.pop(dvs_flip_node[-1]) + # dvs_pool, dvs_crop and dvs_flip are internal nodes of the DVSLayer: we only want an edge from 'dvs' node to the entry points of the network. edges.update({(dvs_node[-1], node) for node in entry_nodes})