fix markdown headings in qujax example
CalMacCQ committed Oct 26, 2023
1 parent 3090a0d commit 68e282b
Showing 2 changed files with 491 additions and 7 deletions.
14 changes: 8 additions & 6 deletions examples/python/pytket-qujax_heisenberg_vqe.py
@@ -5,10 +5,12 @@
from pytket.circuit.display import render_circuit_jupyter
import matplotlib.pyplot as plt


+# ## Let's start with a TKET circuit

import qujax
from pytket.extensions.qujax.qujax_convert import tk_to_qujax

-# # Let's start with a tket circuit
# We place barriers to stop tket automatically rearranging gates and we also store the number of circuit parameters as we'll need this later.


@@ -47,7 +49,7 @@ def get_circuit(n_qubits, depth):
circuit, n_params = get_circuit(n_qubits, depth)
render_circuit_jupyter(circuit)

-# # Now let's invoke qujax
+# ## Now let's invoke qujax
# The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us.

param_to_st = tk_to_qujax(circuit)
@@ -72,7 +74,7 @@ def get_circuit(n_qubits, depth):
sample_probs = jnp.square(jnp.abs(statevector))
plt.bar(jnp.arange(statevector.size), sample_probs)
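As a self-contained aside (plain NumPy rather than JAX, with a made-up statevector that is not taken from the example), this is the Born rule the plot relies on: sampling probabilities are the squared magnitudes of the amplitudes and sum to 1.

```python
import numpy as np

# Hypothetical 2-qubit statevector (|00> + 1j*|11>) / sqrt(2) -- illustration only
statevector = np.array([1.0, 0.0, 0.0, 1.0j]) / np.sqrt(2)

# Born rule: probability of outcome k is |amplitude_k|^2
sample_probs = np.square(np.abs(statevector))

# Probabilities are [0.5, 0, 0, 0.5] and sum to 1
total = sample_probs.sum()
```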

-# # Cost function
+# ## Cost function

# Now we have our `param_to_st` function we are free to define a cost function that acts on bitstrings (e.g. maxcut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.
# Fortunately, we can encode a Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function, which generates a statetensor -> expected value function for us.
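To make the statetensor -> expectation idea concrete, here is a minimal NumPy sketch of the same computation for a single hypothetical ZZ term (the function name and Hamiltonian are illustrative, not the qujax API):

```python
import numpy as np

Z = np.diag([1.0, -1.0])
H = np.kron(Z, Z)  # hypothetical 2-qubit Hamiltonian: a single Z0*Z1 term

def statevector_to_expectation(psi):
    # <psi|H|psi>, real-valued because H is Hermitian
    return float(np.real(np.conj(psi) @ (H @ psi)))

# |01> is an eigenstate of Z0*Z1 with eigenvalue -1
psi_01 = np.array([0.0, 1.0, 0.0, 0.0])
expectation = statevector_to_expectation(psi_01)  # -1.0
```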
@@ -119,7 +121,7 @@ def get_circuit(n_qubits, depth):
)
param_to_expectation(new_params)

-# # We can now use autodiff for fast, exact gradients within a VQE algorithm
+# ## Exact gradients within a VQE algorithm
# The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or even better `jax.value_and_grad`).

cost_and_grad = value_and_grad(param_to_expectation)
@@ -128,7 +130,7 @@ def get_circuit(n_qubits, depth):

cost_and_grad(params)
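For readers new to JAX, here is a stripped-down sketch of `value_and_grad` on a toy scalar function (the quadratic is a stand-in for illustration, not `param_to_expectation`; assumes `jax` is installed):

```python
import jax.numpy as jnp
from jax import value_and_grad

def toy_cost(params):
    return jnp.sum(params ** 2)  # toy stand-in for param_to_expectation

toy_cost_and_grad = value_and_grad(toy_cost)
value, grad = toy_cost_and_grad(jnp.array([1.0, 2.0]))
# value = 5.0, grad = [2.0, 4.0] -- exact derivatives, not finite differences
```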

-# # Now we have all the tools we need to design our VQE!
+# ## Now we have all the tools we need to design our VQE!
# We'll just use vanilla gradient descent with a constant stepsize.
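A minimal sketch of such a loop in plain NumPy, with a hypothetical quadratic `cost_and_grad` standing in for the JAX one above (all names here are illustrative):

```python
import numpy as np

def cost_and_grad(params):
    # Hypothetical stand-in cost with its minimum at params = [1, -2]
    target = np.array([1.0, -2.0])
    diff = params - target
    return np.sum(diff ** 2), 2 * diff

def vanilla_gradient_descent(init_params, n_steps, stepsize):
    params = init_params
    for _ in range(n_steps):
        cost, grad = cost_and_grad(params)
        params = params - stepsize * grad  # constant-stepsize update
    return params

params = vanilla_gradient_descent(np.zeros(2), n_steps=100, stepsize=0.1)
# converges toward [1.0, -2.0]
```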


@@ -164,7 +166,7 @@ def vqe(init_param, n_steps, stepsize):

# Pretty good!

-# # `jax.jit` speedup
+# ## `jax.jit` speedup
# One last thing... We can significantly speed up the VQE above via `jax.jit`. In our current implementation, the expensive `cost_and_grad` function is compiled to [XLA](https://www.tensorflow.org/xla) and then executed at each call. By invoking `jax.jit` we ensure that the function is compiled only once (on the first call) and then simply executed at each future call - this is much faster!

cost_and_grad = jit(cost_and_grad)
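As a self-contained illustration of the compile-once behaviour (a toy function, not the VQE code itself; assumes `jax` is installed):

```python
from jax import jit
import jax.numpy as jnp

def toy_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)

toy_fn_jitted = jit(toy_fn)

x = jnp.arange(1000, dtype=jnp.float32)
first = toy_fn_jitted(x)   # traced and compiled to XLA on this first call
second = toy_fn_jitted(x)  # reuses the compiled executable -- much faster
```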
