Skip to content

Commit

Permalink
Add support for ClExprOp (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-alec authored Nov 4, 2024
1 parent 24880e4 commit 8117266
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Ensure unused classical registers are not omitted

### Added

* PHIR generation for `ClExprOp`

## [0.8.1] - 2024-09-11

### Fixed
Expand Down
113 changes: 112 additions & 1 deletion pytket/phir/phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import pytket
import pytket.circuit as tk
from phir.model import PHIRModel
from pytket.circuit import ClBitVar, ClExpr, ClOp, ClRegVar
from pytket.circuit.clexpr import has_reg_output
from pytket.circuit.logic_exp import (
BitLogicExp,
BitWiseOp,
Expand All @@ -40,7 +42,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from pytket.circuit import Circuit
from pytket.circuit import Circuit, WiredClExpr
from pytket.unit_id import UnitID

from .sharding.shard import Cost, Ordering, ShardLayer
Expand Down Expand Up @@ -383,12 +385,98 @@ def nested_cop(cop: str, args: "deque[UnitID]", val_bits: deque[int]) -> JsonDic
return nested_cop("&", deque(args), deque(map(int, f"{value:0{len(args)}b}")))


def get_cop_from_op(op: ClOp) -> str | int: # noqa: PLR0912
"""Get PHIR classical op name from ClOp."""
cop: str | int
match op:
case ClOp.BitZero | ClOp.RegZero:
cop = 0
case ClOp.BitOne:
cop = 1
case ClOp.RegOne:
cop = -1
case ClOp.BitAnd | ClOp.RegAnd:
cop = "&"
case ClOp.BitOr | ClOp.RegOr:
cop = "|"
case ClOp.BitXor | ClOp.RegXor:
cop = "^"
case ClOp.BitNot | ClOp.RegNot:
cop = "~"
case ClOp.RegLsh:
cop = "<<"
case ClOp.RegRsh:
cop = ">>"
case ClOp.BitEq | ClOp.RegEq:
cop = "=="
case ClOp.BitNeq | ClOp.RegNeq:
cop = "!="
case ClOp.RegLt:
cop = "<"
case ClOp.RegGt:
cop = ">"
case ClOp.RegLeq:
cop = "<="
case ClOp.RegGeq:
cop = ">="
case ClOp.RegAdd:
cop = "+"
case ClOp.RegSub:
cop = "-"
case ClOp.RegMul:
cop = "*"
case ClOp.RegDiv:
cop = "/"
case ClOp.RegPow:
cop = "**"
case _:
logging.exception("Classical operation %s unsupported by PHIR", str(op))
raise NotImplementedError(op)
return cop


def phir_from_clexpr_arg(
expr_arg: int | ClBitVar | ClRegVar | ClExpr,
bit_posn: dict[int, int],
reg_posn: dict[int, list[int]],
bits: list[tkBit],
) -> int | str | list[str | int] | JsonDict:
"""Return PHIR dict for a ClExpr."""
match expr_arg:
case int():
return expr_arg
case ClBitVar():
bit: tkBit = bits[bit_posn[expr_arg.index]]
return arg_to_bit(bit)
case ClRegVar():
bits_in_reg = [bits[i] for i in reg_posn[expr_arg.index]]
reg_size = len(bits_in_reg)
if reg_size == 0:
logging.exception("Register variable with no bits")
reg_name = bits_in_reg[0].reg_name
if any(bit.reg_name != reg_name for bit in bits_in_reg) or any(
bit.index[0] != i for i, bit in enumerate(bits_in_reg)
):
logging.exception("Register variable not aligned with any register")
return reg_name
assert isinstance(expr_arg, ClExpr) # noqa: S101

cop = get_cop_from_op(expr_arg.op)
if isinstance(cop, int):
return cop
args = [
phir_from_clexpr_arg(arg, bit_posn, reg_posn, bits) for arg in expr_arg.args
]
return {"cop": cop, "args": args}


def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: # noqa: PLR0912
"""Return PHIR dict given a tket op and its arguments."""
if op.is_gate():
return convert_gate(op, cmd)

out: JsonDict | None = None
rhs: list[int | str | list[str | int] | JsonDict] = []
match op: # non-quantum op
case tk.Conditional():
out = {
Expand Down Expand Up @@ -430,6 +518,24 @@ def convert_subcmd(op: tk.Op, cmd: tk.Command) -> JsonDict | None: # noqa: PLR0
rhs = [classical_op(exp)]
out = assign_cop([cmd.bits[0].reg_name], rhs)

case tk.ClExprOp():
wexpr: WiredClExpr = op.expr
expr: ClExpr = wexpr.expr
bit_posn: dict[int, int] = wexpr.bit_posn
reg_posn: dict[int, list[int]] = wexpr.reg_posn
output_posn: list[int] = wexpr.output_posn
cmd_args: list[tkBit] = cmd.bits

# TODO(AE): Check that all ClExprOps in the circuit are register-aligned
# (i.e. that each register variable, and the register output if applicable,
# comprises bits that constitute a complete register in the correct order).
# https://github.com/CQCL/tket/issues/1644

rhs = [phir_from_clexpr_arg(expr, bit_posn, reg_posn, cmd_args)]
if has_reg_output(expr.op):
return assign_cop([cmd_args[output_posn[0]].reg_name], rhs)
return assign_cop([arg_to_bit(cmd_args[output_posn[0]])], rhs)

case tk.ClassicalEvalOp():
return convert_classicalevalop(op, cmd)

Expand Down Expand Up @@ -544,6 +650,11 @@ def make_comment_text(cmd: tk.Command, op: tk.Op) -> str:
case RegLogicExp():
comment = str(cmd.bits[0].reg_name) + " = " + str(op.get_exp())

case tk.ClExprOp():
comment = (
str(cmd).split(";")[0] + " of the form " + str(op.expr).split(" [")[0]
)

return comment


Expand Down
1 change: 1 addition & 0 deletions pytket/phir/sharding/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OpType.Barrier,
OpType.SetBits,
OpType.ClassicalExpBox, # some classical operations are rolled up into a box
OpType.ClExpr,
OpType.RangePredicate,
OpType.ExplicitPredicate,
OpType.ExplicitModifier,
Expand Down
23 changes: 23 additions & 0 deletions tests/data/qasm/classical0.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[4];
creg a[2];
creg b[3];
creg d[1];

c = 2;
c = a;
if (b != 2) c[1] = b[1] & a[1] | a[0];
c = b & a;
b = a + b;
b[1] = b[0] + ~b[2];
c = a - (b*c);
d = a << 1;
d = c >> 2;
c[0] = 1;
b = a * c * b;
d[0] = a[0] ^ 1;
if(c>=2) h q[0];
if(d == 1) rx((0.5+0.5)*pi) q[0];
20 changes: 20 additions & 0 deletions tests/data/qasm/classical1.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
OPENQASM 2.0;
include "hqslib1.inc";

qreg q[1];
creg c[4];
creg a[2];
creg b[3];
creg d[1];

c = 2;
c = a;
if (b != 2) c[1] = b[1] & a[1] | a[0];
c = b & a | d;
d[0] = a[0] ^ 1;
if(c>=2) h q[0];
if(c<=2) h q[0];
if(c<2) h q[0];
if(c>2) h q[0];
if(c!=2) h q[0];
if(d == 1) rx((0.5+0.5)*pi) q[0];
111 changes: 111 additions & 0 deletions tests/test_phirgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import json

import pytest

from pytket.circuit import Bit, Circuit
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.phir.api import pytket_to_phir
Expand Down Expand Up @@ -500,3 +502,112 @@ def test_unused_classical_registers() -> None:
"size": 1,
"variable": "a",
}


@pytest.mark.parametrize("use_clexpr", [False, True])
def test_classical_0(*, use_clexpr: bool) -> None:
"""Test handling of ClassicalExpBox/ClExprOp."""
circ = get_qasm_as_circuit(QasmFile.classical0, use_clexpr=use_clexpr)
phir = json.loads(pytket_to_phir(circ))
ops = phir["ops"]
assert {
"cop": "=",
"returns": ["d"],
"args": [{"cop": "<<", "args": ["a", 1]}],
} in ops
assert {
"block": "if",
"condition": {"cop": "==", "args": [["tk_SCRATCH_BIT", 0], 0]},
"true_branch": [
{
"cop": "=",
"returns": [["c", 1]],
"args": [
{
"cop": "|",
"args": [
{"cop": "&", "args": [["b", 1], ["a", 1]]},
["a", 0],
],
}
],
}
],
} in ops
assert {
"cop": "=",
"returns": ["c"],
"args": [{"cop": "&", "args": ["b", "a"]}],
} in ops
assert {
"cop": "=",
"returns": ["b"],
"args": [{"cop": "+", "args": ["a", "b"]}],
} in ops
assert {
"cop": "=",
"returns": [["b", 1]],
"args": [{"cop": "^", "args": [["b", 0], {"cop": "~", "args": [["b", 2]]}]}],
} in ops
assert {
"cop": "=",
"returns": ["c"],
"args": [{"cop": "-", "args": ["a", {"cop": "*", "args": ["b", "c"]}]}],
} in ops
assert {
"cop": "=",
"returns": ["d"],
"args": [{"cop": "<<", "args": ["a", 1]}],
} in ops
assert {
"cop": "=",
"returns": ["d"],
"args": [{"cop": ">>", "args": ["c", 2]}],
} in ops
assert {
"cop": "=",
"returns": ["b"],
"args": [{"cop": "*", "args": [{"cop": "*", "args": ["a", "c"]}, "b"]}],
} in ops
assert {
"cop": "=",
"returns": [["d", 0]],
"args": [{"cop": "^", "args": [["a", 0], 1]}],
} in ops


@pytest.mark.parametrize("use_clexpr", [False, True])
def test_classical_1(*, use_clexpr: bool) -> None:
"""Test handling of ClassicalExpBox/ClExprOp."""
circ = get_qasm_as_circuit(QasmFile.classical1, use_clexpr=use_clexpr)
phir = json.loads(pytket_to_phir(circ))
ops = phir["ops"]
assert {
"block": "if",
"condition": {"cop": "==", "args": [["tk_SCRATCH_BIT", 0], 0]},
"true_branch": [
{
"cop": "=",
"returns": [["c", 1]],
"args": [
{
"cop": "|",
"args": [
{"cop": "&", "args": [["b", 1], ["a", 1]]},
["a", 0],
],
}
],
}
],
} in ops
assert {
"cop": "=",
"returns": ["c"],
"args": [{"cop": "|", "args": [{"cop": "&", "args": ["b", "a"]}, "d"]}],
} in ops
assert {
"cop": "=",
"returns": [["d", 0]],
"args": [{"cop": "^", "args": [["a", 0], 1]}],
} in ops
13 changes: 11 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,34 @@ class QasmFile(Enum):
arbitrary_qreg_names = auto()
group_ordering = auto()
sleep = auto()
classical0 = auto()
classical1 = auto()


class WatFile(Enum):
add = auto()
testfile = auto()


def get_qasm_as_circuit(qasm_file: QasmFile) -> "Circuit":
def get_qasm_as_circuit(
qasm_file: QasmFile,
*,
use_clexpr: bool = False,
) -> "Circuit":
"""Utility function to convert a QASM file to Circuit.
Args:
qasm_file: enum for a QASM file
use_clexpr: convert classical expressions to ClExprOp operations
Returns:
Corresponding tket circuit
"""
this_dir = Path(Path(__file__).resolve()).parent
return circuit_from_qasm(
f"{this_dir}/data/qasm/{qasm_file.name}.qasm", maxwidth=WORDSIZE
f"{this_dir}/data/qasm/{qasm_file.name}.qasm",
maxwidth=WORDSIZE,
use_clexpr=use_clexpr,
)


Expand Down

0 comments on commit 8117266

Please sign in to comment.