Skip to content

Commit

Permalink
Refactor GraphExtractor.remove_nodes_by_class method
Browse files Browse the repository at this point in the history
  • Loading branch information
bauerfe committed Oct 2, 2024
1 parent 8efe58f commit 80c9ef9
Showing 1 changed file with 50 additions and 38 deletions.
88 changes: 50 additions & 38 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# contact : [email protected]

import copy
from typing import Dict, List, Tuple, Type, Set
from typing import Dict, List, Set, Tuple, Type

import nirtorch
import torch
Expand Down Expand Up @@ -114,40 +114,8 @@ def remove_nodes_by_class(
for new_idx, old_idx in enumerate(sorted(source2target.keys()))
}

# Parse new set of edges based on remapped node IDs
self._edges = {
(remapped_nodes[src], remapped_nodes[tgt])
for src, targets in source2target.items()
for tgt in targets
}

# Update name-to-index map based on new node indices
self._name_2_indx_map = {
name: remapped_nodes[old_idx]
for name, old_idx in self._name_2_indx_map.items()
if old_idx in remapped_nodes
}

# Update entry nodes based on new node indices
self._entry_nodes = {
remapped_nodes[old_idx]
for old_idx in self._entry_nodes
if old_idx in remapped_nodes
}

# Update io-shapes based on new node indices
self._nodes_io_shapes = {
remapped_nodes[old_idx]: shape
for old_idx, shape in self._nodes_io_shapes.items()
if old_idx in remapped_nodes
}

# Update sinabs module map based on new node indices
self._modules_map = {
remapped_nodes[old_idx]: module
for old_idx, module in self._modules_map.items()
if old_idx in remapped_nodes
}
# Update internal graph representation according to changes
self._update_internal_representation(remapped_nodes)

def get_node_io_shapes(self, node: int) -> Tuple[torch.Size, torch.Size]:
"""Returns the I/O tensors' shapes of `node`.
Expand Down Expand Up @@ -183,14 +151,14 @@ def _get_edges_from_nir(

# Assign a unique index to each node
name_2_indx_map = {
node.name: node_idx
for node_idx, node in enumerate(nir_graph.node_list)
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
}

# Extract edges for each node
edges = {
(name_2_indx_map[src.name], name_2_indx_map[tgt.name])
for src in nir_graph.node_list for tgt in src.outgoing_nodes
for src in nir_graph.node_list
for tgt in src.outgoing_nodes
}

# find entry nodes of the graph.
Expand All @@ -217,6 +185,50 @@ def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]:
if name in self._name_2_indx_map
}

def _update_internal_representation(self, remapped_nodes: Dict[int, int]):
"""Update internal attributes after remapping of nodes
Parameters
----------
remapped_nodes (dict): Maps previous (key) to new (value) node
indices. Nodes that were removed are not included.
"""

# Parse new set of edges based on remapped node IDs
self._edges = {
(remapped_nodes[src], remapped_nodes[tgt])
for src, targets in source2target.items()
for tgt in targets
}

# Update name-to-index map based on new node indices
self._name_2_indx_map = {
name: remapped_nodes[old_idx]
for name, old_idx in self._name_2_indx_map.items()
if old_idx in remapped_nodes
}

# Update entry nodes based on new node indices
self._entry_nodes = {
remapped_nodes[old_idx]
for old_idx in self._entry_nodes
if old_idx in remapped_nodes
}

# Update io-shapes based on new node indices
self._nodes_io_shapes = {
remapped_nodes[old_idx]: shape
for old_idx, shape in self._nodes_io_shapes.items()
if old_idx in remapped_nodes
}

# Update sinabs module map based on new node indices
self._modules_map = {
remapped_nodes[old_idx]: module
for old_idx, module in self._modules_map.items()
if old_idx in remapped_nodes
}

def _sort_graph_nodes(self) -> List[int]:
"""Sort graph nodes topologically.
Expand Down

0 comments on commit 80c9ef9

Please sign in to comment.