-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor GraphExtractor.remove_nodes_by_class method
- Loading branch information
Showing
1 changed file
with
50 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# contact : [email protected] | ||
|
||
import copy | ||
from typing import Dict, List, Tuple, Type, Set | ||
from typing import Dict, List, Set, Tuple, Type | ||
|
||
import nirtorch | ||
import torch | ||
|
@@ -114,40 +114,8 @@ def remove_nodes_by_class( | |
for new_idx, old_idx in enumerate(sorted(source2target.keys())) | ||
} | ||
|
||
# Parse new set of edges based on remapped node IDs | ||
self._edges = { | ||
(remapped_nodes[src], remapped_nodes[tgt]) | ||
for src, targets in source2target.items() | ||
for tgt in targets | ||
} | ||
|
||
# Update name-to-index map based on new node indices | ||
self._name_2_indx_map = { | ||
name: remapped_nodes[old_idx] | ||
for name, old_idx in self._name_2_indx_map.items() | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update entry nodes based on new node indices | ||
self._entry_nodes = { | ||
remapped_nodes[old_idx] | ||
for old_idx in self._entry_nodes | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update io-shapes based on new node indices | ||
self._nodes_io_shapes = { | ||
remapped_nodes[old_idx]: shape | ||
for old_idx, shape in self._nodes_io_shapes.items() | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update sinabs module map based on new node indices | ||
self._modules_map = { | ||
remapped_nodes[old_idx]: module | ||
for old_idx, module in self._modules_map.items() | ||
if old_idx in remapped_nodes | ||
} | ||
# Update internal graph representation according to changes | ||
self._update_internal_representation(remapped_nodes) | ||
|
||
def get_node_io_shapes(self, node: int) -> Tuple[torch.Size, torch.Size]: | ||
"""Returns the I/O tensors' shapes of `node`. | ||
|
@@ -183,14 +151,14 @@ def _get_edges_from_nir( | |
|
||
# Assign a unique index to each node | ||
name_2_indx_map = { | ||
node.name: node_idx | ||
for node_idx, node in enumerate(nir_graph.node_list) | ||
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list) | ||
} | ||
|
||
# Extract edges for each node | ||
edges = { | ||
(name_2_indx_map[src.name], name_2_indx_map[tgt.name]) | ||
for src in nir_graph.node_list for tgt in src.outgoing_nodes | ||
for src in nir_graph.node_list | ||
for tgt in src.outgoing_nodes | ||
} | ||
|
||
# find entry nodes of the graph. | ||
|
@@ -217,6 +185,50 @@ def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]: | |
if name in self._name_2_indx_map | ||
} | ||
|
||
def _update_internal_representation(self, remapped_nodes: Dict[int, int]): | ||
"""Update internal attributes after remapping of nodes | ||
Parameters | ||
---------- | ||
remapped_nodes (dict): Maps previous (key) to new (value) node | ||
indices. Nodes that were removed are not included. | ||
""" | ||
|
||
# Parse new set of edges based on remapped node IDs | ||
self._edges = { | ||
(remapped_nodes[src], remapped_nodes[tgt]) | ||
for src, targets in source2target.items() | ||
for tgt in targets | ||
} | ||
|
||
# Update name-to-index map based on new node indices | ||
self._name_2_indx_map = { | ||
name: remapped_nodes[old_idx] | ||
for name, old_idx in self._name_2_indx_map.items() | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update entry nodes based on new node indices | ||
self._entry_nodes = { | ||
remapped_nodes[old_idx] | ||
for old_idx in self._entry_nodes | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update io-shapes based on new node indices | ||
self._nodes_io_shapes = { | ||
remapped_nodes[old_idx]: shape | ||
for old_idx, shape in self._nodes_io_shapes.items() | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
# Update sinabs module map based on new node indices | ||
self._modules_map = { | ||
remapped_nodes[old_idx]: module | ||
for old_idx, module in self._modules_map.items() | ||
if old_idx in remapped_nodes | ||
} | ||
|
||
def _sort_graph_nodes(self) -> List[int]: | ||
"""Sort graph nodes topologically. | ||
|