Skip to content

Commit

Permalink
Bug fixes in serialization to the abstract representation (#467)
Browse files Browse the repository at this point in the history
* Fix default basis measurement serialization error

* Fix missing default value for variables

* Bump to version 0.9.3

* Dynamically getting the kwarg default values
  • Loading branch information
HGSilveri authored Feb 24, 2023
1 parent 2decdff commit fee422d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 22 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.2
0.9.3
44 changes: 25 additions & 19 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Utility functions for JSON serialization to the abstract representation."""
from __future__ import annotations

import inspect
import json
from itertools import chain
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -143,6 +144,11 @@ def serialize_abstract_sequence(
for var in seq._variables.values():
value = var._validate_value(defaults[var.name])
res["variables"][var.name]["value"] = value.tolist()
else:
# Still need to set a default value for the variables because the
# deserializer uses it to infer the size of the variable
for var in seq._variables.values():
res["variables"][var.name]["value"] = [var.dtype()] * var.size

def convert_targets(
target_ids: Union[QubitId, abcSequence[QubitId]]
Expand All @@ -154,10 +160,20 @@ def convert_targets(
indices = seq.register.find_indices(target_array.tolist())
return indices[0] if og_dim == 0 else indices

def get_kwarg_default(call_name: str, kwarg_name: str) -> Any:
sig = inspect.signature(getattr(seq, call_name))
return sig.parameters[kwarg_name].default

def get_all_args(
pos_args_signature: tuple[str, ...], call: _Call
) -> dict[str, Any]:
return {**dict(zip(pos_args_signature, call.args)), **call.kwargs}
params = {**dict(zip(pos_args_signature, call.args)), **call.kwargs}
default_values = {
p_name: get_kwarg_default(call.name, p_name)
for p_name in pos_args_signature
if p_name not in params
}
return {**default_values, **params}

operations = res["operations"]
for call in chain(seq._calls, seq._to_build_calls):
Expand All @@ -180,7 +196,7 @@ def get_all_args(
("channel", "channel_id", "initial_target"), call
)
res["channels"][data["channel"]] = data["channel_id"]
if "initial_target" in data and data["initial_target"] is not None:
if data["initial_target"] is not None:
operations.append(
{
"op": "target",
Expand Down Expand Up @@ -222,30 +238,24 @@ def get_all_args(
op_dict = {
"op": "pulse",
"channel": data["channel"],
"protocol": "min-delay"
if "protocol" not in data
else data["protocol"],
"protocol": data["protocol"],
}
op_dict.update(data["pulse"]._to_abstract_repr())
operations.append(op_dict)
elif "phase_shift" in call.name:
try:
basis = call.kwargs["basis"]
except KeyError:
basis = "digital"
targets = call.args[1:]
if call.name == "phase_shift":
targets = convert_targets(targets)
elif call.name == "phase_shift_index":
pass
else:
elif call.name != "phase_shift_index":
raise AbstractReprError(f"Unknown call '{call.name}'.")
operations.append(
{
"op": "phase_shift",
"phi": call.args[0],
"targets": targets,
"basis": basis,
"basis": call.kwargs.get(
"basis", get_kwarg_default(call.name, "basis")
),
}
)
elif call.name == "set_magnetic_field":
Expand All @@ -257,9 +267,7 @@ def get_all_args(
("channel", "amp_on", "detuning_on", "optimal_detuning_off"),
call,
)
# Overwritten if in 'data'
defaults = dict(optimal_detuning_off=0.0)
operations.append({"op": "enable_eom_mode", **defaults, **data})
operations.append({"op": "enable_eom_mode", **data})
elif call.name == "add_eom_pulse":
data = get_all_args(
(
Expand All @@ -271,9 +279,7 @@ def get_all_args(
),
call,
)
# Overwritten if in 'data'
defaults = dict(post_phase_shift=0.0, protocol="min-delay")
operations.append({"op": "add_eom_pulse", **defaults, **data})
operations.append({"op": "add_eom_pulse", **data})
elif call.name == "disable_eom_mode":
data = get_all_args(("channel",), call)
operations.append(
Expand Down
59 changes: 57 additions & 2 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def test_mappable_register(self, triangular_lattice):
"slug": triangular_lattice.slug,
}
assert abstract["register"] == [{"qid": qid} for qid in reg.qubit_ids]
assert abstract["variables"]["var"] == dict(type="int")
assert abstract["variables"]["var"] == dict(type="int", value=[0])

with pytest.raises(
ValueError,
Expand Down Expand Up @@ -504,6 +504,46 @@ def test_eom_mode(self, triangular_lattice):
"channel": "ryd",
}

@pytest.mark.parametrize("use_default", [True, False])
def test_default_basis(
self, triangular_lattice: TriangularLatticeLayout, use_default
):
phase_kwargs = {} if use_default else dict(basis="ground-rydberg")
measure_kwargs = {} if use_default else dict(basis="digital")

seq = Sequence(triangular_lattice.hexagonal_register(5), Chadoq2)
seq.declare_channel("ryd", "rydberg_global")
seq.declare_channel("raman", "raman_local", initial_target="q0")
seq.phase_shift(1, "q0", **phase_kwargs)
seq.phase_shift_index(2, 1, **phase_kwargs)
seq.measure(**measure_kwargs)

abstract = json.loads(seq.to_abstract_repr())
validate_schema(abstract)
assert len(abstract["operations"]) == 3

assert abstract["operations"][0] == {
"op": "target",
"channel": "raman",
"target": 0,
}

assert abstract["operations"][1] == {
"op": "phase_shift",
"basis": phase_kwargs.get("basis", "digital"),
"targets": [0],
"phi": 1,
}
assert abstract["operations"][2] == {
"op": "phase_shift",
"basis": phase_kwargs.get("basis", "digital"),
"targets": [1],
"phi": 2,
}
assert abstract["measurement"] == measure_kwargs.get(
"basis", "ground-rydberg"
)


def _get_serialized_seq(
operations: list[dict] = None,
Expand Down Expand Up @@ -565,6 +605,17 @@ def _check_roundtrip(serialized_seq: dict[str, Any]):
)
assert s == json.loads(rs)

# Remove the defaults and check it still works
for var in seq.declared_variables.values():
s["variables"][var.name]["value"] = [var.dtype()] * var.size
for q in s["register"]:
q.pop("default_trap", None)
s["name"] = "pulser-exported"

seq2 = Sequence.from_abstract_repr(json.dumps(s))
rs_no_defaults = seq2.to_abstract_repr()
assert s == json.loads(rs_no_defaults)


# Needed to replace lambdas in the pytest.mark.parametrize calls (due to mypy)
def _get_op(op: dict) -> Any:
Expand Down Expand Up @@ -660,7 +711,8 @@ def test_deserialize_seq_with_mag_field(self):
seq = Sequence.from_abstract_repr(json.dumps(s))
assert np.all(seq.magnetic_field == mag_field)

def test_deserialize_variables(self):
@pytest.mark.parametrize("without_default", [True, False])
def test_deserialize_variables(self, without_default):
s = _get_serialized_seq(
variables={
"yolo": {"type": "int", "value": [42, 43, 44]},
Expand All @@ -669,6 +721,9 @@ def test_deserialize_variables(self):
)
_check_roundtrip(s)
seq = Sequence.from_abstract_repr(json.dumps(s))
if without_default:
# Serialize and deserialize again, without the defaults
seq = Sequence.from_abstract_repr(seq.to_abstract_repr())

# Check variables
assert len(seq.declared_variables) == len(s["variables"])
Expand Down

0 comments on commit fee422d

Please sign in to comment.