Skip to content

Commit

Permalink
Functional version with qadence, still mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Jan 23, 2024
1 parent dd4aa7a commit 486d828
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 77 deletions.
4 changes: 2 additions & 2 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from .parametric import Rx, Ry, Rz
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from .utils import prepare_state
from .utils import hilbert_reshape, overlap, prepare_state, uniform_state
60 changes: 46 additions & 14 deletions horqrux/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,80 @@
from dataclasses import dataclass
from typing import Any, Iterable, Tuple

import numpy as np
from jax import Array
from jax.tree_util import register_pytree_node_class

from .matrices import OPERATIONS_DICT
from .utils import QubitSupport, _dagger, _jacobian, _unitary
from .utils import QubitSupport, _dagger, _jacobian, _unitary, is_controlled, none_like


@register_pytree_node_class
@dataclass
class Operator:
name: str
generator_name: str
target: QubitSupport
control: QubitSupport

def __post_init__(self) -> None:
def _parse(idx: QubitSupport | Tuple[None, ...]) -> QubitSupport:
return (idx,) if isinstance(idx, int) or idx is None else idx
@staticmethod
def parse_idx(
idx: Tuple,
) -> Tuple:
if isinstance(idx, (int, np.int64)):
return ((idx,),)
elif isinstance(idx, tuple):
return (idx,)
else:
return (idx.astype(int),)

self.target, self.control = list(map(_parse, (self.target, self.control)))
def __post_init__(self) -> None:
self.target = Operator.parse_idx(self.target)
if self.control is None:
self.control = none_like(self.target)
else:
self.control = Operator.parse_idx(self.control)

def __iter__(self) -> Iterable:
return iter((self.name, self.target, self.control))
return iter((self.generator_name, self.target, self.control))

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, QubitSupport, QubitSupport]]:
children = ()
aux_data = (self.name, self.target, self.control)
aux_data = (self.generator_name, self.target, self.control)
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = {}) -> Array:
return OPERATIONS_DICT[self.name]
return OPERATIONS_DICT[self.generator_name]

def dagger(self, values: dict[str, float] = {}) -> Array:
return _dagger(self.unitary(values))

@property
def name(self) -> str:
return "C" + self.generator_name if is_controlled(self.control) else self.generator_name

def __repr__(self) -> str:
return self.name + f"(target={self.target[0]}, control={self.control[0]})"


Primitive = Operator


@register_pytree_node_class
@dataclass
class Parametric(Primitive):
name: str
generator_name: str
target: QubitSupport
control: QubitSupport
param: str | int = ""
param: str | float = ""

def __post_init__(self) -> None:
def parse_dict(self, values: dict[str, float] | float = {}) -> float:
super().__post_init__()

def parse_dict(values: dict[str, float] = {}) -> float:
return values[self.param]

self.parse_values = parse_dict if isinstance(self.param, str) else lambda x: self.param
Expand All @@ -70,7 +92,17 @@ def tree_flatten(self) -> Tuple[Tuple, Tuple[str, QubitSupport, QubitSupport, st
return (children, aux_data)

def unitary(self, values: dict[str, float] = {}) -> Array:
return _unitary(OPERATIONS_DICT[self.name], self.parse_values(values))
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

def jacobian(self, values: dict[str, float] = {}) -> Array:
return _jacobian(OPERATIONS_DICT[self.name], values[self.param])
return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

@property
def name(self) -> str:
base_name = "R" + self.generator_name
return "C" + base_name if is_controlled(self.control) else base_name

def __repr__(self) -> str:
return (
self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})"
)
35 changes: 17 additions & 18 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,38 @@

from horqrux.abstract import Operator

from .utils import QubitSupport, State, hilbert_reshape, make_controlled
from .utils import ControlQubits, State, TargetQubits, _controlled, is_controlled


def apply_operator(
state: State,
operator: Array,
target: QubitSupport,
control: QubitSupport,
target: TargetQubits,
control: ControlQubits,
) -> State:
"""Applies a single or series of operators to the given state. The operators 'operator' should
either be an array over whose first axis we can iterate (e.g. [N_gates, 2 x 2])
"""Applies a single or series of operators to the given state.
The operators are expected to either be an array over whose first axis we can iterate (e.g. [N_gates, 2 x 2])
or if you have a mix of single and multi qubit gates a tuple or list like [O_1, O_2, ...].
This function then sequentially applies this gates, adding control bits
as necessary and returning the state after applying all the gates.
Args:
state (Array): Input state to operate on.
operator (Union[Iterable, Array]): Iterable or array of operator matrixes to apply.
target_idx (TargetIdx): Target indices, Tuple of Tuple of ints.
control_idx (ControlIdx): Control indices, Tuple of length target_idex of None or Tuple.
state: Input state to operate on.
operator: List of arrays or array of operators to contract over the state.
target: Target indices, Tuple of Tuple of ints.
control: Control indices, Tuple of length target_idex of None or Tuple.
Returns:
Array: Changed state.
"""

target = (target,) if isinstance(target, int) else target
qubits = target
if control is not None:
control = (control,) if isinstance(control, int) else control
operator = make_controlled(operator, len(control))
qubits = control + target
operator = hilbert_reshape(operator) if len(target) > 1 else operator
assert isinstance(control, tuple)
if is_controlled(control):
operator = _controlled(operator, len(control))
qubits = (*control, *target)
n_qubits = int(np.log2(operator.size))
operator = operator.reshape(tuple(2 for _ in np.arange(n_qubits)))
op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int))
state = jnp.tensordot(a=operator, b=state, axes=(op_dims, qubits))
return jnp.moveaxis(a=state, source=np.arange(len(qubits)), destination=qubits)
Expand All @@ -51,10 +50,10 @@ def apply_operator(
def apply_gate(
state: State, gate: Operator | Iterable[Operator], values: dict[str, float] = {}
) -> State:
"""Applies gate to given state. Essentially a simple wrapper around
"""Applies a gate to given state. Essentially a simple wrapper around
apply_operator, see that docstring for more info.
Args:
Arguments:
state (Array): State to operate on.
gate (Gate): Gate(s) to apply.
Expand Down
60 changes: 34 additions & 26 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,60 @@
from .utils import ControlQubits, TargetQubits


def Rx(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Rx gate.
def RX(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""RX gate.
Args:
theta (float): Rotational angle.
target (TargetIdx): Tuple of Tuples describing the qubits to apply to.
control (ControlIdx, optional): Tuple of Tuples or Nones describing
the control qubits of length(target). Defaults to (None,).
Arguments:
param: Parameter denoting the Rotational angle.
target: Tuple of target qubits denoted as ints.
control: Optional tuple of control qubits denoted as ints.
Returns:
Gate: Gate object.
Parametric: A Parametric gate object.
"""
return Parametric("X", target, control, param=param)


def Ry(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Ry gate.
def RY(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""RY gate.
Args:
theta (float): Rotational angle.
target (TargetIdx): Tuple of Tuples describing the qubits to apply to.
control (ControlIdx, optional): Tuple of Tuples or Nones describing
the control qubits of length(target). Defaults to (None,).
Arguments:
param: Parameter denoting the Rotational angle.
target: Tuple of target qubits denoted as ints.
control: Optional tuple of control qubits denoted as ints.
Returns:
Gate: Gate object.
Parametric: A Parametric gate object.
"""
return Parametric("Y", target, control, param=param)


def Rz(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Rz gate.
def RZ(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""RZ gate.
Args:
theta (float): Rotational angle.
target (TargetIdx): Tuple of Tuples describing the qubits to apply to.
control (ControlIdx, optional): Tuple of Tuples or Nones describing
the control qubits of length(target). Defaults to (None,).
Arguments:
param: Parameter denoting the Rotational angle.
target: Tuple of target qubits denoted as ints.
control: Optional tuple of control qubits denoted as ints.
Returns:
Gate: Gate object.
Parametric: A Parametric gate object.
"""
return Parametric("Z", target, control, param=param)


def Phase(param: float | str, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
def PHASE(param: float, target: TargetQubits, control: ControlQubits = (None,)) -> Parametric:
"""Phase gate.
Arguments:
param: Parameter denoting the Rotational angle.
target: Tuple of target qubits denoted as ints.
control: Optional tuple of control qubits denoted as ints.
Returns:
Parametric: A Parametric gate object.
"""

def unitary(values: dict[str, float] = {}) -> Array:
u = jnp.eye(2, 2, dtype=jnp.complex128)
u = u.at[(1, 1)].set(jnp.exp(1.0j * values[param]))
Expand All @@ -63,7 +71,7 @@ def jacobian(values: dict[str, float] = {}) -> Array:
jac = jac.at[(1, 1)].set(1j * jnp.exp(1.0j * values[param]))
return jac

phase = Parametric("I", target, control)
phase = Parametric("I", target, control, param)
phase.name = "PHASE"
phase.unitary = unitary
phase.jacobian = jacobian
Expand Down
10 changes: 5 additions & 5 deletions horqrux/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def I(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:

def NOT(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""NOT gate. Note that since we lazily evaluate the circuit, this function
returns the gate representationt of Gate type and does *not* apply the gate.
returns the gate representation of Gate type and does *not* apply the gate.
By providing a control idx it turns into a controlled gate, use None for no control qubits.
Example usage: `NOT(((1, ), ), (None, ))` applies the NOT to qubit 1.
Expand All @@ -42,7 +42,7 @@ def NOT(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:

def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""X gate. Note that since we lazily evaluate the circuit, this function
returns the gate representationt of Gate type and does *not* apply the gate.
returns the gate representation of Gate type and does *not* apply the gate.
By providing a control idx it turns into a controlled gate, use None for no control qubits.
Example usage: X(((1, ), ), (None, )) applies the NOT to qubit 1.
Expand All @@ -61,7 +61,7 @@ def X(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:

def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Y gate. Note that since we lazily evaluate the circuit, this function
returns the gate representationt of Gate type and does *not* apply the gate.
returns the gate representation of Gate type and does *not* apply the gate.
By providing a control idx it turns into a controlled gate, use None for no control qubits.
Example usage: Y(((1, ), ), (None, )) applies the NOT to qubit 1.
Expand All @@ -80,7 +80,7 @@ def Y(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:

def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""Z gate. Note that since we lazily evaluate the circuit, this function
returns the gate representationt of Gate type and does *not* apply the gate.
returns the gate representation of Gate type and does *not* apply the gate.
By providing a control idx it turns into a controlled gate, use None for no control qubits.
Example usage: Z(((1, ), ), (None, )) applies the NOT to qubit 1.
Expand All @@ -99,7 +99,7 @@ def Z(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:

def H(target: TargetQubits, control: ControlQubits = (None,)) -> Primitive:
"""H gate. Note that since we lazily evaluate the circuit, this function
returns the gate representationt of Gate type and does *not* apply the gate.
returns the gate representation of Gate type and does *not* apply the gate.
By providing a control idx it turns into a controlled gate, use None for no control qubits.
Example usage: H(((1, ), ), (None, )) applies the NOT to qubit 1.
Expand Down
33 changes: 25 additions & 8 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,34 @@

State = ArrayLike
QubitSupport = Tuple[Any, ...]
ControlQubits = QubitSupport
TargetQubits = QubitSupport
ControlQubits = Tuple[None | Tuple[int, ...], ...]
TargetQubits = Tuple[Tuple[int, ...], ...]


def _dagger(operator: Array) -> Array:
return jnp.conjugate(operator.T)


def _unitary(generator: Array, theta: float) -> Array:
return jnp.cos(theta / 2) * jnp.eye(2) - 1j * jnp.sin(theta / 2) * generator
return (
jnp.cos(theta / 2) * jnp.eye(2, dtype=jnp.complex128) - 1j * jnp.sin(theta / 2) * generator
)


def _jacobian(generator: Array, theta: float) -> Array:
return -1 / 2 * (jnp.sin(theta / 2) * jnp.eye(2) + 1j * jnp.cos(theta / 2)) * generator
return (
-1
/ 2
* (jnp.sin(theta / 2) * jnp.eye(2, dtype=jnp.complex128) + 1j * jnp.cos(theta / 2))
* generator
)


def make_controlled(operator: Array, n_control: int) -> Array:
def _controlled(operator: Array, n_control: int) -> Array:
n_qubits = int(log2(operator.shape[0]))
_controlled = jnp.eye(2 ** (n_control + n_qubits))
_controlled = _controlled.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator)
return hilbert_reshape(_controlled)
control = jnp.eye(2 ** (n_control + n_qubits), dtype=jnp.complex128)
control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator)
return control


def prepare_state(n_qubits: int, state: str = None) -> Array:
Expand Down Expand Up @@ -104,3 +111,13 @@ def uniform_state(
state = jnp.ones(2**n_qubits, dtype=jnp.complex128)
state = state / jnp.sqrt(jnp.array(2**n_qubits))
return state.reshape([2] * n_qubits)


def is_controlled(qs: ControlQubits) -> bool:
if qs is None:
return False
if isinstance(qs, tuple):
return any(isinstance(q, int) for q in qs)
for s in qs:
return any([qubit is not None for qubit in s])
return False
Loading

0 comments on commit 486d828

Please sign in to comment.