diff --git a/_modules/hippynn/graphs/nodes/networks.html b/_modules/hippynn/graphs/nodes/networks.html
index 1182a80a..4a37790f 100644
--- a/_modules/hippynn/graphs/nodes/networks.html
+++ b/_modules/hippynn/graphs/nodes/networks.html
@@ -81,6 +81,7 @@
Source code for hippynn.graphs.nodes.networks
"""
from .tags import Encoder, PairIndexer, Network, AtomIndexer
from .base import _BaseNode, AutoKw, ExpandParents, SingleNode
+from .base.multi import IndexNode
from .indexers import OneHotEncoder, PaddingIndexer, acquire_encoding_padding
from .pairs import OpenPairIndexer, PeriodicPairIndexer
from .inputs import SpeciesNode, PositionsNode, CellNode
@@ -172,12 +173,14 @@ Source code for hippynn.graphs.nodes.networks
parents = self.expand_parents(
parents, species_set=net_module.species_set, dist_hard_max=net_module.dist_hard_max, periodic=periodic
)
+ super().__init__(name, parents, module=net_module)
- super().__init__(name, parents, module=net_module)
+ _make_feature_nodes(self)
+
[docs]
class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode):
@@ -212,7 +215,9 @@
Source code for hippynn.graphs.nodes.networks
parents, species_set=net_module.species_set, dist_hard_max=net_module.dist_hard_max, periodic=periodic
)
- super().__init__(name, parents, module=net_module)
+
super().__init__(name, parents, module=net_module)
+
+
_make_feature_nodes(self)
@@ -226,6 +231,35 @@ Source code for hippynn.graphs.nodes.networks
_auto_module_class = network_modules.hipnn.HipnnQuad
+
+
+def _make_feature_nodes(network_node):
+ """
+ This function can be used on a network to make nodes that refer to the individual feature blocks.
+ :param network_node: the input network, which is modified in-place
+ :return: None
+ """
+ import warnings
+ warnings.warn("This function is included for backwards compatibility and may be removed in a future release. "
+ "The preferred way to access these nodes is through `network.feature_nodes`, which is available on "
+ "networks created with this version of hippynn or later.")
+
+ if hasattr(network_node, "feature_nodes"):
+ return network_node.feature_nodes
+
+ net_module = network_node.torch_module
+ n_interactions = net_module.ni
+
+ feature_nodes = []
+
+ index_state = IdxType.Atoms
+ parents = (network_node,)
+ for i in range(n_interactions + 1):
+ name = f"{network_node.name}_features_{i}"
+ fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state)
+ feature_nodes.append(fnode)
+ network_node.feature_nodes = feature_nodes
+