Skip to content

Commit

Permalink
Update documentations
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Apr 10, 2024
1 parent fa0c4d7 commit bb182b6
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 37 deletions.
61 changes: 36 additions & 25 deletions _modules/hippynn/pretraining.html
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ <h1>Source code for hippynn.pretraining</h1><div class="highlight"><pre>
<span class="sd">Things to do before training, i.e. initialization of network and diagnostics.</span>
<span class="sd">&quot;&quot;&quot;</span>

<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">torch</span>

Expand All @@ -93,11 +94,11 @@ <h1>Source code for hippynn.pretraining</h1><div class="highlight"><pre>
<span class="kn">from</span> <span class="nn">.networks.hipnn</span> <span class="kn">import</span> <span class="n">compute_hipnn_e0</span>


<div class="viewcode-block" id="set_e0_values">
<a class="viewcode-back" href="../../api_documentation/hippynn.pretraining.html#hippynn.pretraining.set_e0_values">[docs]</a>
<span class="k">def</span> <span class="nf">set_e0_values</span><span class="p">(</span>
<div class="viewcode-block" id="hierarchical_energy_initialization">
<a class="viewcode-back" href="../../api_documentation/hippynn.pretraining.html#hippynn.pretraining.hierarchical_energy_initialization">[docs]</a>
<span class="k">def</span> <span class="nf">hierarchical_energy_initialization</span><span class="p">(</span>
<span class="n">energy_module</span><span class="p">,</span>
<span class="n">database</span><span class="p">,</span>
<span class="n">database</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">trainable_after</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">decay_factor</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span>
<span class="n">encoder</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
Expand All @@ -109,12 +110,13 @@ <h1>Source code for hippynn.pretraining</h1><div class="highlight"><pre>
<span class="sd"> Computes values for the non-interacting energy using the training data.</span>

<span class="sd"> :param energy_module: HEnergyNode or torch module for energy prediction</span>
<span class="sd"> :param database: InterfaceDB object to get training data</span>
<span class="sd"> :param trainable_after: Determines if it should change .requires_grad attribute for the E0 parameters.</span>
<span class="sd"> :param database: InterfaceDB object to get training data, required if model contains E0 term</span>
<span class="sd"> :param trainable_after: Determines if it should change .requires_grad attribute for the E0 parameters</span>
<span class="sd"> :param decay_factor: change initialized weights of further energy layers by ``df**N`` for layer N</span>
<span class="sd"> :param network_module: network for running the species encoding. Can be auto-identified from energy node</span>
<span class="sd"> :param encoder: species encoder, can be auto-identified from energy node</span>
<span class="sd"> :param energy_name: name for the energy variable, can be auto-identified from energy node</span>
<span class="sd"> :param species_name: name for the species variable, can be auto-identified from energy node</span>
<span class="sd"> :param peratom:</span>
<span class="sd"> :return: None</span>
<span class="sd"> &quot;&quot;&quot;</span>

Expand All @@ -131,32 +133,41 @@ <h1>Source code for hippynn.pretraining</h1><div class="highlight"><pre>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">encoder</span><span class="p">,</span> <span class="n">_BaseNode</span><span class="p">):</span>
<span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span><span class="o">.</span><span class="n">torch_module</span>

<span class="n">train_data</span> <span class="o">=</span> <span class="n">database</span><span class="o">.</span><span class="n">splits</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">]</span>

<span class="n">z_vals</span> <span class="o">=</span> <span class="n">train_data</span><span class="p">[</span><span class="n">species_name</span><span class="p">]</span>
<span class="n">t_vals</span> <span class="o">=</span> <span class="n">train_data</span><span class="p">[</span><span class="n">energy_name</span><span class="p">]</span>

<span class="n">encoder</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_vals</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">eovals</span> <span class="o">=</span> <span class="n">compute_hipnn_e0</span><span class="p">(</span><span class="n">encoder</span><span class="p">,</span> <span class="n">z_vals</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">,</span> <span class="n">peratom</span><span class="o">=</span><span class="n">peratom</span><span class="p">)</span>
<span class="n">eo_layer</span> <span class="o">=</span> <span class="n">energy_module</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">if</span> <span class="ow">not</span> <span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">eovals</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;The function set_eo_values does not currently work with custom InputNodes.&quot;</span><span class="p">)</span>
<span class="c1"># If model has E0 term, set its initial value using the database provided</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">energy_module</span><span class="o">.</span><span class="n">first_is_interacting</span><span class="p">:</span>
<span class="k">if</span> <span class="n">database</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Database must be provided if model includes E0 energy term.&quot;</span><span class="p">)</span>

<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">eovals</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Computed E0 energies:&quot;</span><span class="p">,</span> <span class="n">eovals</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Computed E0 energies:&quot;</span><span class="p">,</span> <span class="n">eovals</span><span class="p">)</span>
<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">eovals</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Computed E0 energies:&quot;</span><span class="p">,</span> <span class="n">eovals</span><span class="p">)</span>
<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">eovals</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">database</span><span class="o">.</span><span class="n">splits</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">]</span>

<span class="n">z_vals</span> <span class="o">=</span> <span class="n">train_data</span><span class="p">[</span><span class="n">species_name</span><span class="p">]</span>
<span class="n">t_vals</span> <span class="o">=</span> <span class="n">train_data</span><span class="p">[</span><span class="n">energy_name</span><span class="p">]</span>

<span class="n">encoder</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">t_vals</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">eovals</span> <span class="o">=</span> <span class="n">compute_hipnn_e0</span><span class="p">(</span><span class="n">encoder</span><span class="p">,</span> <span class="n">z_vals</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">,</span> <span class="n">peratom</span><span class="o">=</span><span class="n">peratom</span><span class="p">)</span>
<span class="n">eo_layer</span> <span class="o">=</span> <span class="n">energy_module</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">if</span> <span class="ow">not</span> <span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">eovals</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The shape of the computed E0 values does not match the shape expected by the model.&quot;</span><span class="p">)</span>

<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">eovals</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Computed E0 energies:&quot;</span><span class="p">,</span> <span class="n">eovals</span><span class="p">)</span>
<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">eovals</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="n">trainable_after</span><span class="p">)</span>

<span class="n">eo_layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="n">trainable_after</span><span class="p">)</span>
<span class="c1"># Decay layers E1, E2, etc... according to decay_factor</span>
<span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">energy_module</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">1</span><span class="p">:]:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span> <span class="o">*=</span> <span class="n">decay_factor</span>
<span class="n">layer</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">data</span> <span class="o">*=</span> <span class="n">decay_factor</span>
<span class="n">decay_factor</span> <span class="o">*=</span> <span class="n">decay_factor</span></div>


<div class="viewcode-block" id="set_e0_values">
<a class="viewcode-back" href="../../api_documentation/hippynn.pretraining.html#hippynn.pretraining.set_e0_values">[docs]</a>
<span class="k">def</span> <span class="nf">set_e0_values</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The function set_e0_values is depreciated. Please use the hierarchical_energy_initialization function instead.&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">hierarchical_energy_initialization</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>


<span class="k">def</span> <span class="nf">_setup_min_dist_graph</span><span class="p">(</span>
<span class="n">species_name</span><span class="p">,</span>
Expand Down
4 changes: 2 additions & 2 deletions _sources/examples/minimal_workflow.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ Now we'll load a database::

Now that we have a database and a model, we can fit the non-interacting energies using the training set in the database::

from hippynn.pretraining import set_e0_values
set_e0_values(henergy,database,trainable_after=False)
from hippynn.pretraining import hierarchical_energy_initialization
hierarchical_energy_initialization(henergy,database,trainable_after=False)

We're almost there. We specify the training procedure with ``SetupParams``. We need to have

Expand Down
1 change: 1 addition & 0 deletions api_documentation/hippynn.html
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,7 @@ <h1>hippynn package<a class="headerlink" href="#hippynn-package" title="Link to
<li class="toctree-l1"><a class="reference internal" href="hippynn.pretraining.html">pretraining module</a><ul>
<li class="toctree-l2"><a class="reference internal" href="hippynn.pretraining.html#hippynn.pretraining.calculate_max_system_force"><code class="docutils literal notranslate"><span class="pre">calculate_max_system_force()</span></code></a></li>
<li class="toctree-l2"><a class="reference internal" href="hippynn.pretraining.html#hippynn.pretraining.calculate_min_dists"><code class="docutils literal notranslate"><span class="pre">calculate_min_dists()</span></code></a></li>
<li class="toctree-l2"><a class="reference internal" href="hippynn.pretraining.html#hippynn.pretraining.hierarchical_energy_initialization"><code class="docutils literal notranslate"><span class="pre">hierarchical_energy_initialization()</span></code></a></li>
<li class="toctree-l2"><a class="reference internal" href="hippynn.pretraining.html#hippynn.pretraining.set_e0_values"><code class="docutils literal notranslate"><span class="pre">set_e0_values()</span></code></a></li>
</ul>
</li>
Expand Down
Loading

0 comments on commit bb182b6

Please sign in to comment.