diff --git a/examples/python/pytket-qujax_heisenberg_vqe.py b/examples/python/pytket-qujax_heisenberg_vqe.py
index 3c050c9e..fc840f49 100644
--- a/examples/python/pytket-qujax_heisenberg_vqe.py
+++ b/examples/python/pytket-qujax_heisenberg_vqe.py
@@ -5,10 +5,12 @@
from pytket.circuit.display import render_circuit_jupyter
import matplotlib.pyplot as plt
+
+# ## Let's start with a TKET circuit
+
import qujax
from pytket.extensions.qujax.qujax_convert import tk_to_qujax
-# # Let's start with a tket circuit
# We place barriers to stop tket automatically rearranging gates and we also store the number of circuit parameters as we'll need this later.
@@ -47,7 +49,7 @@ def get_circuit(n_qubits, depth):
circuit, n_params = get_circuit(n_qubits, depth)
render_circuit_jupyter(circuit)
-# # Now let's invoke qujax
+# ## Now let's invoke qujax
# The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us.
param_to_st = tk_to_qujax(circuit)
@@ -72,7 +74,7 @@ def get_circuit(n_qubits, depth):
sample_probs = jnp.square(jnp.abs(statevector))
plt.bar(jnp.arange(statevector.size), sample_probs)
-# # Cost function
+# ## Cost function
# Now we have our `param_to_st` function we are free to define a cost function that acts on bitstrings (e.g. maxcut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.
# Fortunately, we can encode an Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function which generates a statetensor -> expected value function for us.
@@ -119,7 +121,7 @@ def get_circuit(n_qubits, depth):
)
param_to_expectation(new_params)
-# # We can now use autodiff for fast, exact gradients within a VQE algorithm
+# ## Exact gradients within a VQE algorithm
# The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or even better `jax.value_and_grad`).
cost_and_grad = value_and_grad(param_to_expectation)
@@ -128,7 +130,7 @@ def get_circuit(n_qubits, depth):
cost_and_grad(params)
-# # Now we have all the tools we need to design our VQE!
+# ## Now we have all the tools we need to design our VQE!
# We'll just use vanilla gradient descent with a constant stepsize
@@ -164,7 +166,7 @@ def vqe(init_param, n_steps, stepsize):
# Pretty good!
-# # `jax.jit` speedup
+# ## `jax.jit` speedup
# One last thing... We can significantly speed up the VQE above via the `jax.jit`. In our current implementation, the expensive `cost_and_grad` function is compiled to [XLA](https://www.tensorflow.org/xla) and then executed at each call. By invoking `jax.jit` we ensure that the function is compiled only once (on the first call) and then simply executed at each future call - this is much faster!
cost_and_grad = jit(cost_and_grad)
diff --git a/examples/pytket-qujax_heisenberg_vqe.ipynb b/examples/pytket-qujax_heisenberg_vqe.ipynb
index 2ec55022..36f288f4 100644
--- a/examples/pytket-qujax_heisenberg_vqe.ipynb
+++ b/examples/pytket-qujax_heisenberg_vqe.ipynb
@@ -1 +1,483 @@
-{"cells": [{"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from pytket import Circuit\n", "from pytket.circuit.display import render_circuit_jupyter\n", "from jax import numpy as jnp, random, vmap, grad, value_and_grad, jit\n", "import matplotlib.pyplot as plt"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import qujax\n", "from pytket.extensions.qujax import tk_to_qujax"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# Let's start with a tket circuit
\n", "We place barriers to stop tket automatically rearranging gates and we also store the number of circuit parameters as we'll need this later."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["def get_circuit(n_qubits, depth):\n", " n_params = 2 * n_qubits * (depth + 1)\n", " param = jnp.zeros((n_params,))\n", " circuit = Circuit(n_qubits)\n", " k = 0\n", " for i in range(n_qubits):\n", " circuit.H(i)\n", " for i in range(n_qubits):\n", " circuit.Rx(param[k], i)\n", " k += 1\n", " for i in range(n_qubits):\n", " circuit.Ry(param[k], i)\n", " k += 1\n", " for _ in range(depth):\n", " for i in range(0, n_qubits - 1):\n", " circuit.CZ(i, i + 1)\n", " circuit.add_barrier(range(0, n_qubits))\n", " for i in range(n_qubits):\n", " circuit.Rx(param[k], i)\n", " k += 1\n", " for i in range(n_qubits):\n", " circuit.Ry(param[k], i)\n", " k += 1\n", " return circuit, n_params"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["n_qubits = 4\n", "depth = 2\n", "circuit, n_params = get_circuit(n_qubits, depth)\n", "render_circuit_jupyter(circuit)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# Now let's invoke qujax
\n", "The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["param_to_st = tk_to_qujax(circuit)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Let's try it out on some random parameters values. Be aware that's JAX's random number generator requires a `jax.random.PRNGkey` every time it's called - more info on that [here](https://jax.readthedocs.io/en/latest/jax.random.html).
\n", "Be aware that we still have convention where parameters are specified as multiples of $\\pi$ - that is in [0,2]."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["params = random.uniform(random.PRNGKey(0), shape=(n_params,), minval=0., maxval=2.)\n", "statetensor = param_to_st(params)\n", "print(statetensor)\n", "print(statetensor.shape)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Note that this function also has an optional second argument where an initiating `statetensor_in` can be provided. If it is not provided it will default to the all 0s state (as we use here)."]}, {"cell_type": "markdown", "metadata": {}, "source": ["We can obtain statevector by simply calling `.flatten()`"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["statevector = statetensor.flatten()\n", "statevector.shape"]}, {"cell_type": "markdown", "metadata": {}, "source": ["And sampling probabilities by squaring the absolute value of the statevector"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["sample_probs = jnp.square(jnp.abs(statevector))\n", "plt.bar(jnp.arange(statevector.size), sample_probs);"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# Cost function"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Now we have our `param_to_st` function we are free to define a cost function that acts on bitstrings (e.g. maxcut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.
\n", "Fortunately, we can encode an Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function which generates a statetensor -> expected value function for us.
\n", "It takes three arguments as input
\n", "- `gate_seq_seq`: A list of string (or array) lists encoding the gates in each term of the Hamiltonian. I.e. `[['X','X'], ['Y','Y'], ['Z','Z']]` corresponds to $H = aX_iX_j + bY_kY_l + cZ_mZ_n$ with qubit indices $i,j,k,l,m,n$ specified in the second argument and coefficients $a,b,c$ specified in the third argument
\n", "- `qubit_inds_seq`: A list of integer lists encoding which qubit indices to apply the aforementioned gates. I.e. `[[0, 1],[0,1],[0,1]]`. Must have the same structure as `gate_seq_seq` above.
\n", "- `coefficients`: A list of floats encoding any coefficients in the Hamiltonian. I.e. `[2.3, 0.8, 1.2]` corresponds to $a=2.3,b=0.8,c=1.2$ above. Must have the same length as the two above arguments."]}, {"cell_type": "markdown", "metadata": {}, "source": ["More specifically let's consider the problem of finding the ground state of the quantum Heisenberg Hamiltonian
\n", "$$ H = \\sum_{i=1}^{n_\\text{qubits}-1} X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}. $$
\n", "As described, we define the Hamiltonian via its gate strings, qubit indices and coefficients."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["hamiltonian_gates = [['X', 'X'], ['Y', 'Y'], ['Z', 'Z']] * (n_qubits - 1)\n", "hamiltonian_qubit_inds = [[int(i), int(i) + 1] for i in jnp.repeat(jnp.arange(n_qubits), 3)]\n", "coefficients = [1.] * len(hamiltonian_qubit_inds)"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["print('Gates:\\t', hamiltonian_gates)\n", "print('Qubits:\\t', hamiltonian_qubit_inds)\n", "print('Coefficients:\\t', coefficients)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Now let's get the Hamiltonian as a pure JAX function"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["st_to_expectation = qujax.get_statetensor_to_expectation_func(hamiltonian_gates,\n", " hamiltonian_qubit_inds,\n", " coefficients)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Let's check it works on the statetensor we've already generated."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["expected_val = st_to_expectation(statetensor)\n", "expected_val"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Now let's wrap the `param_to_st` and `st_to_expectation` together to give us an all in one `param_to_expectation` cost function."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["param_to_expectation = lambda param: st_to_expectation(param_to_st(param))"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["param_to_expectation(params)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Sanity check that a different, randomly generated set of parameters gives us a new expected value."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["new_params = random.uniform(random.PRNGKey(1), shape=(n_params,), minval=0., maxval=2.)\n", "param_to_expectation(new_params)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# We can now use autodiff for fast, exact gradients within a VQE algorithm
\n", "The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or even better `jax.value_and_grad`)."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["cost_and_grad = value_and_grad(param_to_expectation)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The `cost_and_grad` function returns a tuple with the exact cost value and exact gradient evaluated at the parameters."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["cost_and_grad(params)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# Now we have all the tools we need to design our VQE!
\n", "We'll just use vanilla gradient descent with a constant stepsize"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["def vqe(init_param, n_steps, stepsize):\n", " params = jnp.zeros((n_steps, n_params))\n", " params = params.at[0].set(init_param)\n", " cost_vals = jnp.zeros(n_steps)\n", " cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n", " for step in range(1, n_steps):\n", " cost_val, cost_grad = cost_and_grad(params[step - 1])\n", " cost_vals = cost_vals.at[step].set(cost_val)\n", " new_param = params[step - 1] - stepsize * cost_grad\n", " params = params.at[step].set(new_param)\n", " print('Iteration:', step, '\\tCost:', cost_val, end='\\r')\n", " print('\\n')\n", " return params, cost_vals"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Ok enough talking, let's run (and whilst we're at it we'll time it too)"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["%time vqe_params, vqe_cost_vals = vqe(params, n_steps=250, stepsize=0.01)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Let's plot the results..."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["plt.plot(vqe_cost_vals)\n", "plt.xlabel('Iteration')\n", "plt.ylabel('Cost');"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Pretty good!"]}, {"cell_type": "markdown", "metadata": {}, "source": ["# `jax.jit` speedup
\n", "One last thing... We can significantly speed up the VQE above via the `jax.jit`. In our current implementation, the expensive `cost_and_grad` function is compiled to [XLA](https://www.tensorflow.org/xla) and then executed at each call. By invoking `jax.jit` we ensure that the function is compiled only once (on the first call) and then simply executed at each future call - this is much faster!"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["cost_and_grad = jit(cost_and_grad)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["We'll demonstrate this using the second set of initial parameters we randomly generated (to be sure of no caching)."]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["%time new_vqe_params, new_vqe_cost_vals = vqe(new_params, n_steps=250, stepsize=0.01)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["That's some speedup!
\n", "But let's also plot the training to be sure it converged correctly"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["plt.plot(new_vqe_cost_vals)\n", "plt.xlabel('Iteration')\n", "plt.ylabel('Cost');"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4"}}, "nbformat": 4, "nbformat_minor": 2}
\ No newline at end of file
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# VQE example with pytket-qujax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax import numpy as jnp, random, value_and_grad, jit\n",
+ "from pytket import Circuit\n",
+ "from pytket.circuit.display import render_circuit_jupyter\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Let's start with a TKET circuit"
+   ]
+  },
+  {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import qujax\n",
+ "from pytket.extensions.qujax.qujax_convert import tk_to_qujax"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We place barriers to stop tket automatically rearranging gates and we also store the number of circuit parameters as we'll need this later."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_circuit(n_qubits, depth):\n",
+ " n_params = 2 * n_qubits * (depth + 1)\n",
+ " param = jnp.zeros((n_params,))\n",
+ " circuit = Circuit(n_qubits)\n",
+ " k = 0\n",
+ " for i in range(n_qubits):\n",
+ " circuit.H(i)\n",
+ " for i in range(n_qubits):\n",
+ " circuit.Rx(param[k], i)\n",
+ " k += 1\n",
+ " for i in range(n_qubits):\n",
+ " circuit.Ry(param[k], i)\n",
+ " k += 1\n",
+ " for _ in range(depth):\n",
+ " for i in range(0, n_qubits - 1):\n",
+ " circuit.CZ(i, i + 1)\n",
+ " circuit.add_barrier(range(0, n_qubits))\n",
+ " for i in range(n_qubits):\n",
+ " circuit.Rx(param[k], i)\n",
+ " k += 1\n",
+ " for i in range(n_qubits):\n",
+ " circuit.Ry(param[k], i)\n",
+ " k += 1\n",
+ " return circuit, n_params"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_qubits = 4\n",
+ "depth = 2\n",
+ "circuit, n_params = get_circuit(n_qubits, depth)\n",
+ "render_circuit_jupyter(circuit)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Now let's invoke qujax
\n",
+ "The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "param_to_st = tk_to_qujax(circuit)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's try it out on some random parameters values. Be aware that's JAX's random number generator requires a `jax.random.PRNGkey` every time it's called - more info on that [here](https://jax.readthedocs.io/en/latest/jax.random.html).
\n",
+ "Be aware that we still have convention where parameters are specified as multiples of $\\pi$ - that is in [0,2]."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params = random.uniform(random.PRNGKey(0), shape=(n_params,), minval=0.0, maxval=2.0)\n",
+ "statetensor = param_to_st(params)\n",
+ "print(statetensor)\n",
+ "print(statetensor.shape)"
+ ]
+ },
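+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the multiples-of-$\\pi$ convention concrete, here is a short illustration (our addition, not part of the original example): the actual rotation angles in radians are simply the parameters multiplied by $\\pi$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustration: parameters lie in [0, 2] and are interpreted as multiples of pi,\n",
+    "# so the rotation angles in radians are params * pi\n",
+    "angles_rad = params * jnp.pi\n",
+    "print(params[:4])\n",
+    "print(angles_rad[:4])"
+   ]
+  },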
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that this function also has an optional second argument where an initiating `statetensor_in` can be provided. If it is not provided it will default to the all 0s state (as we use here)."
+ ]
+ },
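+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sketch (assuming the second argument is the keyword `statetensor_in`), we can feed a previously computed statetensor back in as the initial state - here effectively applying the parameterised circuit twice."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: reuse an existing statetensor as the initial state\n",
+    "# (assumes param_to_st accepts a statetensor_in argument)\n",
+    "st_chained = param_to_st(params, statetensor_in=statetensor)\n",
+    "st_chained.shape"
+   ]
+  },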
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can obtain statevector by simply calling `.flatten()`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "statevector = statetensor.flatten()\n",
+ "statevector.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And sampling probabilities by squaring the absolute value of the statevector"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample_probs = jnp.square(jnp.abs(statevector))\n",
+ "plt.bar(jnp.arange(statevector.size), sample_probs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Cost function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we have our `param_to_st` function we are free to define a cost function that acts on bitstrings (e.g. maxcut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.
\n",
+ "Fortunately, we can encode an Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function which generates a statetensor -> expected value function for us.
\n",
+ "It takes three arguments as input
\n",
+ "- `gate_seq_seq`: A list of string (or array) lists encoding the gates in each term of the Hamiltonian. I.e. `[['X','X'], ['Y','Y'], ['Z','Z']]` corresponds to $H = aX_iX_j + bY_kY_l + cZ_mZ_n$ with qubit indices $i,j,k,l,m,n$ specified in the second argument and coefficients $a,b,c$ specified in the third argument
\n",
+ "- `qubit_inds_seq`: A list of integer lists encoding which qubit indices to apply the aforementioned gates. I.e. `[[0, 1],[0,1],[0,1]]`. Must have the same structure as `gate_seq_seq` above.
\n",
+ "- `coefficients`: A list of floats encoding any coefficients in the Hamiltonian. I.e. `[2.3, 0.8, 1.2]` corresponds to $a=2.3,b=0.8,c=1.2$ above. Must have the same length as the two above arguments."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "More specifically let's consider the problem of finding the ground state of the quantum Heisenberg Hamiltonian
\n",
+ "$$ H = \\sum_{i=1}^{n_\\text{qubits}-1} X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}. $$
\n",
+ "As described, we define the Hamiltonian via its gate strings, qubit indices and coefficients."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hamiltonian_gates = [[\"X\", \"X\"], [\"Y\", \"Y\"], [\"Z\", \"Z\"]] * (n_qubits - 1)\n",
+ "hamiltonian_qubit_inds = [\n",
+ " [int(i), int(i) + 1] for i in jnp.repeat(jnp.arange(n_qubits), 3)\n",
+ "]\n",
+ "coefficients = [1.0] * len(hamiltonian_qubit_inds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Gates:\\t\", hamiltonian_gates)\n",
+ "print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n",
+ "print(\"Coefficients:\\t\", coefficients)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now let's get the Hamiltonian as a pure JAX function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "st_to_expectation = qujax.get_statetensor_to_expectation_func(\n",
+ " hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's check it works on the statetensor we've already generated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "expected_val = st_to_expectation(statetensor)\n",
+ "expected_val"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now let's wrap the `param_to_st` and `st_to_expectation` together to give us an all in one `param_to_expectation` cost function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "param_to_expectation = lambda param: st_to_expectation(param_to_st(param))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "param_to_expectation(params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Sanity check that a different, randomly generated set of parameters gives us a new expected value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_params = random.uniform(\n",
+ " random.PRNGKey(1), shape=(n_params,), minval=0.0, maxval=2.0\n",
+ ")\n",
+ "param_to_expectation(new_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# We can now use autodiff for fast, exact gradients within a VQE algorithm
\n",
+ "The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or even better `jax.value_and_grad`)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cost_and_grad = value_and_grad(param_to_expectation)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `cost_and_grad` function returns a tuple with the exact cost value and exact gradient evaluated at the parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cost_and_grad(params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Now we have all the tools we need to design our VQE!
\n",
+ "We'll just use vanilla gradient descent with a constant stepsize"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def vqe(init_param, n_steps, stepsize):\n",
+ " params = jnp.zeros((n_steps, n_params))\n",
+ " params = params.at[0].set(init_param)\n",
+ " cost_vals = jnp.zeros(n_steps)\n",
+ " cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n",
+ " for step in range(1, n_steps):\n",
+ " cost_val, cost_grad = cost_and_grad(params[step - 1])\n",
+ " cost_vals = cost_vals.at[step].set(cost_val)\n",
+ " new_param = params[step - 1] - stepsize * cost_grad\n",
+ " params = params.at[step].set(new_param)\n",
+ " print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")\n",
+ " print(\"\\n\")\n",
+ " return params, cost_vals"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Ok enough talking, let's run (and whilst we're at it we'll time it too)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqe_params, vqe_cost_vals = vqe(params, n_steps=250, stepsize=0.01)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's plot the results..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.plot(vqe_cost_vals)\n",
+ "plt.xlabel(\"Iteration\")\n",
+ "plt.ylabel(\"Cost\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Pretty good!"
+ ]
+ },
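+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional sanity check (a sketch we add here, not part of the original example), for this small system we can diagonalise the dense Heisenberg Hamiltonian directly and compare its exact ground-state energy with the final VQE cost."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (our addition): exact ground energy by dense diagonalisation.\n",
+    "# Only feasible for small n_qubits since the matrix is 2^n x 2^n.\n",
+    "import numpy as np\n",
+    "\n",
+    "X = np.array([[0.0, 1.0], [1.0, 0.0]])\n",
+    "Y = np.array([[0.0, -1.0j], [1.0j, 0.0]])\n",
+    "Z = np.array([[1.0, 0.0], [0.0, -1.0]])\n",
+    "I2 = np.eye(2)\n",
+    "\n",
+    "def neighbour_term(op, i, n):\n",
+    "    # Kronecker product with op acting on qubits i and i + 1\n",
+    "    mats = [op if j in (i, i + 1) else I2 for j in range(n)]\n",
+    "    out = mats[0]\n",
+    "    for m in mats[1:]:\n",
+    "        out = np.kron(out, m)\n",
+    "    return out\n",
+    "\n",
+    "H = sum(\n",
+    "    neighbour_term(P, i, n_qubits)\n",
+    "    for i in range(n_qubits - 1)\n",
+    "    for P in (X, Y, Z)\n",
+    ")\n",
+    "print(\"Exact ground energy:\", np.linalg.eigvalsh(H).min())\n",
+    "print(\"Final VQE cost:\", vqe_cost_vals[-1])"
+   ]
+  },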
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# `jax.jit` speedup
\n",
+ "One last thing... We can significantly speed up the VQE above via the `jax.jit`. In our current implementation, the expensive `cost_and_grad` function is compiled to [XLA](https://www.tensorflow.org/xla) and then executed at each call. By invoking `jax.jit` we ensure that the function is compiled only once (on the first call) and then simply executed at each future call - this is much faster!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cost_and_grad = jit(cost_and_grad)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We'll demonstrate this using the second set of initial parameters we randomly generated (to be sure of no caching)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_vqe_params, new_vqe_cost_vals = vqe(new_params, n_steps=250, stepsize=0.01)"
+ ]
+ },
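+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A small timing caveat (our note, not from the original example): JAX dispatches computations asynchronously, so when micro-benchmarking a jitted function on its own you should block on its output. Here the `print` inside `vqe` already forces synchronisation at each step."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Caveat: JAX arrays are evaluated asynchronously, so block on the result\n",
+    "# to be sure the computation has finished before a timer stops\n",
+    "val, _ = cost_and_grad(new_params)\n",
+    "val.block_until_ready()"
+   ]
+  },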
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "That's some speedup!
\n",
+ "But let's also plot the training to be sure it converged correctly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.plot(new_vqe_cost_vals)\n",
+ "plt.xlabel(\"Iteration\")\n",
+ "plt.ylabel(\"Cost\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}