diff --git a/VERSION.txt b/VERSION.txt index f76f91317..b3ec1638f 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.9.2 \ No newline at end of file +0.9.3 \ No newline at end of file diff --git a/pulser-core/pulser/json/abstract_repr/serializer.py b/pulser-core/pulser/json/abstract_repr/serializer.py index 437096c9b..239b5e10f 100644 --- a/pulser-core/pulser/json/abstract_repr/serializer.py +++ b/pulser-core/pulser/json/abstract_repr/serializer.py @@ -14,6 +14,7 @@ """Utility functions for JSON serialization to the abstract representation.""" from __future__ import annotations +import inspect import json from itertools import chain from typing import TYPE_CHECKING, Any @@ -143,6 +144,11 @@ def serialize_abstract_sequence( for var in seq._variables.values(): value = var._validate_value(defaults[var.name]) res["variables"][var.name]["value"] = value.tolist() + else: + # Still need to set a default value for the variables because the + # deserializer uses it to infer the size of the variable + for var in seq._variables.values(): + res["variables"][var.name]["value"] = [var.dtype()] * var.size def convert_targets( target_ids: Union[QubitId, abcSequence[QubitId]] @@ -154,10 +160,20 @@ def convert_targets( indices = seq.register.find_indices(target_array.tolist()) return indices[0] if og_dim == 0 else indices + def get_kwarg_default(call_name: str, kwarg_name: str) -> Any: + sig = inspect.signature(getattr(seq, call_name)) + return sig.parameters[kwarg_name].default + def get_all_args( pos_args_signature: tuple[str, ...], call: _Call ) -> dict[str, Any]: - return {**dict(zip(pos_args_signature, call.args)), **call.kwargs} + params = {**dict(zip(pos_args_signature, call.args)), **call.kwargs} + default_values = { + p_name: get_kwarg_default(call.name, p_name) + for p_name in pos_args_signature + if p_name not in params + } + return {**default_values, **params} operations = res["operations"] for call in chain(seq._calls, seq._to_build_calls): @@ -180,7 +196,7 @@ def get_all_args( ("channel", "channel_id", "initial_target"), call ) res["channels"][data["channel"]] = data["channel_id"] - if "initial_target" in data and data["initial_target"] is not None: + if data["initial_target"] is not None: operations.append( { "op": "target", @@ -222,30 +238,24 @@ def get_all_args( op_dict = { "op": "pulse", "channel": data["channel"], - "protocol": "min-delay" - if "protocol" not in data - else data["protocol"], + "protocol": data["protocol"], } op_dict.update(data["pulse"]._to_abstract_repr()) operations.append(op_dict) elif "phase_shift" in call.name: - try: - basis = call.kwargs["basis"] - except KeyError: - basis = "digital" targets = call.args[1:] if call.name == "phase_shift": targets = convert_targets(targets) - elif call.name == "phase_shift_index": - pass - else: + elif call.name != "phase_shift_index": raise AbstractReprError(f"Unknown call '{call.name}'.") operations.append( { "op": "phase_shift", "phi": call.args[0], "targets": targets, - "basis": basis, + "basis": call.kwargs.get( + "basis", get_kwarg_default(call.name, "basis") + ), } ) elif call.name == "set_magnetic_field": @@ -257,9 +267,7 @@ def get_all_args( ("channel", "amp_on", "detuning_on", "optimal_detuning_off"), call, ) - # Overwritten if in 'data' - defaults = dict(optimal_detuning_off=0.0) - operations.append({"op": "enable_eom_mode", **defaults, **data}) + operations.append({"op": "enable_eom_mode", **data}) elif call.name == "add_eom_pulse": data = get_all_args( ( @@ -271,9 +279,7 @@ def get_all_args( ), call, ) - # Overwritten if in 'data' - defaults = dict(post_phase_shift=0.0, protocol="min-delay") - operations.append({"op": "add_eom_pulse", **defaults, **data}) + operations.append({"op": "add_eom_pulse", **data}) elif call.name == "disable_eom_mode": data = get_all_args(("channel",), call) operations.append( diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index ee99bc99e..1bb4c4e78 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -439,7 +439,7 @@ def test_mappable_register(self, triangular_lattice): "slug": triangular_lattice.slug, } assert abstract["register"] == [{"qid": qid} for qid in reg.qubit_ids] - assert abstract["variables"]["var"] == dict(type="int") + assert abstract["variables"]["var"] == dict(type="int", value=[0]) with pytest.raises( ValueError, @@ -504,6 +504,46 @@ def test_eom_mode(self, triangular_lattice): "channel": "ryd", } + @pytest.mark.parametrize("use_default", [True, False]) + def test_default_basis( + self, triangular_lattice: TriangularLatticeLayout, use_default + ): + phase_kwargs = {} if use_default else dict(basis="ground-rydberg") + measure_kwargs = {} if use_default else dict(basis="digital") + + seq = Sequence(triangular_lattice.hexagonal_register(5), Chadoq2) + seq.declare_channel("ryd", "rydberg_global") + seq.declare_channel("raman", "raman_local", initial_target="q0") + seq.phase_shift(1, "q0", **phase_kwargs) + seq.phase_shift_index(2, 1, **phase_kwargs) + seq.measure(**measure_kwargs) + + abstract = json.loads(seq.to_abstract_repr()) + validate_schema(abstract) + assert len(abstract["operations"]) == 3 + + assert abstract["operations"][0] == { + "op": "target", + "channel": "raman", + "target": 0, + } + + assert abstract["operations"][1] == { + "op": "phase_shift", + "basis": phase_kwargs.get("basis", "digital"), + "targets": [0], + "phi": 1, + } + assert abstract["operations"][2] == { + "op": "phase_shift", + "basis": phase_kwargs.get("basis", "digital"), + "targets": [1], + "phi": 2, + } + assert abstract["measurement"] == measure_kwargs.get( + "basis", "ground-rydberg" + ) + def _get_serialized_seq( operations: list[dict] = None, @@ -565,6 +605,17 @@ def _check_roundtrip(serialized_seq: dict[str, Any]): ) assert s == json.loads(rs) + # Remove the defaults and check it still works + for var in seq.declared_variables.values(): + s["variables"][var.name]["value"] = [var.dtype()] * var.size + for q in s["register"]: + q.pop("default_trap", None) + s["name"] = "pulser-exported" + + seq2 = Sequence.from_abstract_repr(json.dumps(s)) + rs_no_defaults = seq2.to_abstract_repr() + assert s == json.loads(rs_no_defaults) + # Needed to replace lambdas in the pytest.mark.parametrize calls (due to mypy) def _get_op(op: dict) -> Any: @@ -660,7 +711,8 @@ def test_deserialize_seq_with_mag_field(self): seq = Sequence.from_abstract_repr(json.dumps(s)) assert np.all(seq.magnetic_field == mag_field) - def test_deserialize_variables(self): + @pytest.mark.parametrize("without_default", [True, False]) + def test_deserialize_variables(self, without_default): s = _get_serialized_seq( variables={ "yolo": {"type": "int", "value": [42, 43, 44]}, @@ -669,6 +721,9 @@ def test_deserialize_variables(self): ) _check_roundtrip(s) seq = Sequence.from_abstract_repr(json.dumps(s)) + if without_default: + # Serialize and deserialize again, without the defaults + seq = Sequence.from_abstract_repr(seq.to_abstract_repr()) # Check variables assert len(seq.declared_variables) == len(s["variables"])