diff --git a/_modules/hippynn/graphs/nodes/networks.html b/_modules/hippynn/graphs/nodes/networks.html index 1182a80a..4a37790f 100644 --- a/_modules/hippynn/graphs/nodes/networks.html +++ b/_modules/hippynn/graphs/nodes/networks.html @@ -81,6 +81,7 @@

Source code for hippynn.graphs.nodes.networks

""" from .tags import Encoder, PairIndexer, Network, AtomIndexer from .base import _BaseNode, AutoKw, ExpandParents, SingleNode +from .base.multi import IndexNode from .indexers import OneHotEncoder, PaddingIndexer, acquire_encoding_padding from .pairs import OpenPairIndexer, PeriodicPairIndexer from .inputs import SpeciesNode, PositionsNode, CellNode @@ -172,12 +173,14 @@

Source code for hippynn.graphs.nodes.networks

parents = self.expand_parents( parents, species_set=net_module.species_set, dist_hard_max=net_module.dist_hard_max, periodic=periodic ) + super().__init__(name, parents, module=net_module) - super().__init__(name, parents, module=net_module)
+ _make_feature_nodes(self)
+
[docs] class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode): @@ -212,7 +215,9 @@

Source code for hippynn.graphs.nodes.networks

parents, species_set=net_module.species_set, dist_hard_max=net_module.dist_hard_max, periodic=periodic ) - super().__init__(name, parents, module=net_module)
+ super().__init__(name, parents, module=net_module) + + _make_feature_nodes(self)
@@ -226,6 +231,35 @@

Source code for hippynn.graphs.nodes.networks

_auto_module_class = network_modules.hipnn.HipnnQuad
+ + +def _make_feature_nodes(network_node): + """ + This function can be used on a network to make nodes that refer to the individual feature blocks. + :param network_node: the input network, which is modified in-place + :return: None + """ + import warnings + warnings.warn("This function is included for backwards compatibility and may be removed in a future release. " + "The preferred way to access these nodes is through `network.feature_nodes`, which is available on " + "networks created with this version of hippynn or later.") + + if hasattr(network_node, "feature_nodes"): + return network_node.feature_nodes + + net_module = network_node.torch_module + n_interactions = net_module.ni + + feature_nodes = [] + + index_state = IdxType.Atoms + parents = (network_node,) + for i in range(n_interactions + 1): + name = f"{network_node.name}_features_{i}" + fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state) + feature_nodes.append(fnode) + network_node.feature_nodes = feature_nodes +