Skip to content

Commit

Permalink
DONE - DVS node given
Browse files Browse the repository at this point in the history
constructors of GraphExtractor and DynapcnnNetwork executing without errors
  • Loading branch information
Willian-Girao committed Oct 30, 2024
1 parent 430af98 commit 2988ce2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
14 changes: 9 additions & 5 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,15 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu
# input shape for `DVSLayer` instance that will be the module of the node 'dvs'.
_, _, height, width = dummy_input.shape
self._add_dvs_node(dvs_input_shape=(height, width))

# TODO - the calll bellow should be done outside this 'if' cuz the problem only
# appears when the DVSLayer is given as the first layer of `spiking_model`.
# Fix NIR edges extraction when DVS is a node in the graph.
fix_dvs_module_edges(edges=self._edges, indx_2_module_map=self._indx_2_module_map)

# Check for the need of fixing NIR edges extraction when DVS is a node in the graph. If DVS
# is used its node becomes the only entry node in the graph.
fix_dvs_module_edges(
edges=self._edges,
indx_2_module_map=self._indx_2_module_map,
name_2_indx_map=self._name_2_indx_map,
entry_nodes=self._entry_nodes,
)

# Verify that graph is compatible
self.verify_graph_integrity()
Expand Down
47 changes: 29 additions & 18 deletions sinabs/backend/dynapcnn/sinabs_edges_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,25 @@ def get_dvs_node_from_mapper(dcnnl_map: Dict) -> Optional[Dict]:
return layer_info
return None

def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Module]) -> None:
""" Modifies `edges` in-place to re-structure the edges related witht the DVSLayer instance. The DVSLayer's forward method
feeds that in the sequence `DVS.pool -> DVS.crop -> DVS.flip`, so we remove edges involving these nodes that are internaly
implementend in the DVSLayer instance from the graph and point the node of DVSLayer directly to the layer/module it is suppoed
to forward its data to.
Modifies `indx_2_module_map` in-place to remove the nodes (Crop2d, FlipDims and DVSLayer's pooling) defined within the DVSLayer
instance since those are not independent nodes of the final graph.
def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Module], name_2_indx_map: Dict[str, int], entry_nodes: Set[Edge]) -> None:
""" All arguments are modified in-place to fix wrong node extractions from NIRtorch when a DVSLayer istance is the first layer in the network.
Currently, this is also removing a self-recurrent node with edge `(FlipDims, FlipDims)` that is
created when forwarding via DVSLayer.
Modifies `edges` to re-structure the edges related witht the DVSLayer instance. The DVSLayer's forward method feeds data in the
sequence 'DVS -> DVS.pool -> DVS.crop -> DVS.flip', so we remove edges involving these nodes (that are internaly implementend in
the DVSLayer) from the graph and point the node of DVSLayer to the node where it should send its output to. This is also removes
a self-recurrent node with edge '(FlipDims, FlipDims)' that is wrongly extracted.
Modifies `indx_2_module_map` and `name_2_indx_map` to remove the internal DVSLayer nodes (Crop2d, FlipDims and DVSLayer's pooling) since
these should not be independent nodes in the graph.
Modifies `entry_nodes` such that the DVSLayer becomes the only entry node of the graph.
Parameters
----------
- edges (set): tuples describing the connections between layers in `spiking_model`.
- indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`).
- name_2_indx_map (dict): Map from node names to unique indices.
- entry_nodes (set): IDs of nodes acting as entry points for the network (i.e., receiving external input).
"""
# TODO - the 'fix_' is to imply there's something odd with the extracted adges for the forward pass implemented by
# the DVSLayer. For now this function is fixing these edges to have them representing the information flow through
Expand All @@ -82,14 +85,12 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul
})

# NIR is extracting and edge (FlipDims, FlipDims) from the DVSLayer: remove self-recurrent nodes from the graph.
edges = {(src, tgt) for (src, tgt) in edges if not (src == tgt and isinstance(indx_2_module_map[src], FlipDims))}
for edge in [(src, tgt) for (src, tgt) in edges if (src == tgt and isinstance(indx_2_module_map[src], FlipDims))]:
edges.remove(edge)

# Since NIR is not extracting the edges for the DVSLayer correctly, remove all edges involving the DVS.
edges = {(src, tgt) for (src, tgt) in edges if src not in dvslayer_nodes and tgt not in dvslayer_nodes}

# Get what the entry nodes should be without the DVS - these are the ones the DVS should point to.
all_sources, all_targets = zip(*edges)
entry_nodes = set(all_sources) - set(all_targets)
for edge in [(src, tgt) for (src, tgt) in edges if (src in dvslayer_nodes or tgt in dvslayer_nodes)]:
edges.remove(edge)

# Get node's indexes based on the module type - just for validation.
dvs_node = [key for key, value in dvslayer_nodes.items() if isinstance(value, DVSLayer)]
Expand All @@ -104,9 +105,19 @@ def fix_dvs_module_edges(edges: Set[Edge], indx_2_module_map: Dict[int, nn.Modul
indx_2_module_map.pop(dvs_pool_node[-1])
indx_2_module_map.pop(dvs_crop_node[-1])
indx_2_module_map.pop(dvs_flip_node[-1])

# Remove internal DVS modeules from name/index map.
for name in [name for name, index in name_2_indx_map.items() if index in [dvs_pool_node[-1], dvs_crop_node[-1], dvs_flip_node[-1]]]:
name_2_indx_map.pop(name)

# dvs_pool, dvs_crop and dvs_flip are internal nodes of the DVSLayer: we only want an edge from 'dvs' node to the entry points of the network.
edges.update({(dvs_node[-1], node) for node in entry_nodes})
# Add edges from 'dvs' node to the entry point of the graph.
all_sources, all_targets = zip(*edges)
local_entry_nodes = set(all_sources) - set(all_targets)
edges.update({(dvs_node[-1], node) for node in local_entry_nodes})

# DVS becomes the only entry node of the graph.
entry_nodes.clear()
entry_nodes.add(dvs_node[-1])

def collect_dynapcnn_layer_info(
indx_2_module_map: Dict[int, nn.Module],
Expand Down

0 comments on commit 2988ce2

Please sign in to comment.