Skip to content

Commit

Permalink
Refactoring and feature isolation to ease subclassing (#464)
Browse files Browse the repository at this point in the history
* Isolating modulation method in Channel

* Isolating standard samples

* Isolating figure creationg in Sequence.draw()

* Add option to consider the phase of detuned delays

* Bump version to 0.9.2

* Import sorting

* Restrict black version

* Adjust padding on `keep_ends=True`

* Move comment
  • Loading branch information
HGSilveri authored Feb 16, 2023
1 parent 74d809c commit 2decdff
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 39 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.1
0.9.2
3 changes: 1 addition & 2 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# tests
black
black[jupyter]
black[jupyter] < 23.1
flake8
flake8-docstrings
isort
Expand Down
52 changes: 40 additions & 12 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ def validate_pulse(self, pulse: Pulse) -> None:
"allowed for the chosen channel."
)

@property
def _modulation_padding(self) -> int:
"""The padding added to the input signals before modulation.
Defined in number of samples to pad before and after signal
(i.e. the signal is extended by 2*_modulation_padding).
"""
return self.rise_time

def modulate(
self,
input_samples: np.ndarray,
Expand All @@ -380,7 +389,7 @@ def modulate(
raise TypeError(f"The channel {self} does not have an EOM.")
eom_config = cast(BaseEOM, self.eom_config)
mod_bandwidth = eom_config.mod_bandwidth
rise_time = eom_config.rise_time
mod_padding = eom_config.rise_time

elif not self.mod_bandwidth:
warnings.warn(
Expand All @@ -391,22 +400,41 @@ def modulate(
return input_samples
else:
mod_bandwidth = self.mod_bandwidth
rise_time = self.rise_time
mod_padding = self._modulation_padding

# The cutoff frequency (fc) and the modulation transfer function
# are defined in https://tinyurl.com/bdeumc8k
fc = mod_bandwidth * 1e-3 / np.sqrt(np.log(2))
if keep_ends:
samples = np.pad(input_samples, 2 * rise_time, mode="edge")
samples = np.pad(
input_samples, mod_padding + self.rise_time, mode="edge"
)
else:
samples = np.pad(input_samples, rise_time)
freqs = fftfreq(samples.size)
modulation = np.exp(-(freqs**2) / fc**2)
mod_samples = ifft(fft(samples) * modulation).real
samples = np.pad(input_samples, mod_padding)
mod_samples = self.apply_modulation(samples, mod_bandwidth)
if keep_ends:
# Cut off the extra ends
return cast(np.ndarray, mod_samples[rise_time:-rise_time])
return cast(np.ndarray, mod_samples)
return mod_samples[self.rise_time : -self.rise_time]
return mod_samples

@staticmethod
def apply_modulation(
input_samples: np.ndarray, mod_bandwidth: float
) -> np.ndarray:
"""Applies the modulation transfer fuction to the input samples.
Note:
This is strictly the application of the modulation transfer
function. The samples should be padded beforehand.
Args:
input_samples: The samples to modulate.
mod_bandwidth: The modulation bandwidth at -3dB (50% reduction),
in MHz.
"""
# The cutoff frequency (fc) and the modulation transfer function
# are defined in https://tinyurl.com/bdeumc8k
fc = mod_bandwidth * 1e-3 / np.sqrt(np.log(2))
freqs = fftfreq(input_samples.size)
modulation = np.exp(-(freqs**2) / fc**2)
return cast(np.ndarray, ifft(fft(input_samples) * modulation).real)

def calc_modulation_buffer(
self,
Expand Down
4 changes: 3 additions & 1 deletion pulser-core/pulser/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
if TYPE_CHECKING:
from pulser import Sequence

IGNORE_DETUNED_DELAY_PHASE = True


def sample(
seq: Sequence,
Expand All @@ -27,7 +29,7 @@ def sample(

samples_list = []
for ch_schedule in seq._schedule.values():
samples = ch_schedule.get_samples()
samples = ch_schedule.get_samples(IGNORE_DETUNED_DELAY_PHASE)
if extended_duration:
samples = samples.extend_duration(extended_duration)
if modulation:
Expand Down
30 changes: 18 additions & 12 deletions pulser-core/pulser/sampler/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,27 @@ def is_empty(self) -> bool:
"""
return np.count_nonzero(self.amp) + np.count_nonzero(self.det) == 0

def _generate_std_samples(self) -> ChannelSamples:
new_samples = {
key: getattr(self, key).copy() for key in ("amp", "det")
}
for block in self.eom_blocks:
region = slice(block.ti, block.tf)
new_samples["amp"][region] = 0
# For modulation purposes, the detuning on the standard
# samples is kept at 'detuning_off', which permits a smooth
# transition to/from the EOM modulated samples
new_samples["det"][region] = block.detuning_off

return replace(self, **new_samples)

def modulate(
self, channel_obj: Channel, max_duration: Optional[int] = None
) -> ChannelSamples:
"""Modulates the samples for a given channel.
It assumes that the phase starts at its initial value and is kept at
its final value. The same could potentially be done for the detuning,
but it's not as safe of an assumption so it's not done for now.
It assumes that the detuning and phase start at their initial values
and are kept at their final values.
Args:
channel_obj: The channel object for which to modulate the samples.
Expand All @@ -173,14 +186,12 @@ def masked(samples: np.ndarray, mask: np.ndarray) -> np.ndarray:

new_samples: dict[str, np.ndarray] = {}

std_samples = {
key: getattr(self, key).copy() for key in ("amp", "det")
}
eom_samples = {
key: getattr(self, key).copy() for key in ("amp", "det")
}

if self.eom_blocks:
std_samples = self._generate_std_samples()
# Note: self.duration already includes the fall time
eom_mask = np.zeros(self.duration, dtype=bool)
# Extension of the EOM mask outside of the EOM interval
Expand All @@ -190,11 +201,6 @@ def masked(samples: np.ndarray, mask: np.ndarray) -> np.ndarray:
# If block.tf is None, uses the full duration as the tf
end = block.tf or self.duration
eom_mask[block.ti : end] = True
std_samples["amp"][block.ti : end] = 0
# For modulation purposes, the detuning on the standard
# samples is kept at 'detuning_off', which permits a smooth
# transition to/from the EOM modulated samples
std_samples["det"][block.ti : end] = block.detuning_off
# Extends EOM masks to include fall time
ext_end = end + eom_fall_time
eom_mask_ext[end:ext_end] = True
Expand All @@ -215,7 +221,7 @@ def masked(samples: np.ndarray, mask: np.ndarray) -> np.ndarray:
# we mask them to include only the parts outside the EOM mask
# This ensures smooth transitions between EOM and STD samples
modulated_std = channel_obj.modulate(
std_samples[key], keep_ends=key == "det"
getattr(std_samples, key), keep_ends=key == "det"
)
std = masked(modulated_std, ~eom_mask)

Expand Down
17 changes: 11 additions & 6 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def adjust_duration(self, duration: int) -> int:
max(duration, self.channel_obj.min_duration)
)

def get_samples(self) -> ChannelSamples:
def get_samples(
self, ignore_detuned_delay_phase: bool = True
) -> ChannelSamples:
"""Returns the samples of the channel."""
# Keep only pulse slots
channel_slots = [s for s in self.slots if isinstance(s.type, Pulse)]
Expand All @@ -155,16 +157,19 @@ def get_samples(self) -> ChannelSamples:
)
slots.append(_TargetSlot(s.ti, tf, s.targets))

# The phase of detuned delays is not considered
if self.is_detuned_delay(pulse):
if ignore_detuned_delay_phase and self.is_detuned_delay(pulse):
# The phase of detuned delays is not considered
continue

ph_jump_t = self.channel_obj.phase_jump_time
for last_pulse_ind in range(ind - 1, -1, -1): # From ind-1 to 0
last_pulse_slot = channel_slots[last_pulse_ind]
# Skips over detuned delay pulses
if not self.is_detuned_delay(
cast(Pulse, last_pulse_slot.type)
if not (
ignore_detuned_delay_phase
and self.is_detuned_delay(
cast(Pulse, last_pulse_slot.type)
)
):
# Accounts for when pulse is added with 'no-delay'
# i.e. there is no phase_jump_time in between a phase jump
Expand All @@ -173,7 +178,7 @@ def get_samples(self) -> ChannelSamples:
else:
t_start = 0
# Overrides all values from t_start on. The next pulses will do
# the same, so the last phase is automatically kept till the endm
# the same, so the last phase is automatically kept till the end
phase[t_start:] = pulse.phase

return ChannelSamples(amp, det, phase, slots, self.eom_blocks)
Expand Down
2 changes: 1 addition & 1 deletion pulser-core/pulser/sequence/_seq_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def draw_sequence(
draw_input: bool = True,
draw_modulation: bool = False,
draw_phase_curve: bool = False,
) -> tuple[Figure, Figure]:
) -> tuple[Figure | None, Figure]:
"""Draws the entire sequence.
Args:
Expand Down
10 changes: 6 additions & 4 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from pulser.sequence._basis_ref import _QubitRef
from pulser.sequence._call import _Call
from pulser.sequence._schedule import _ChannelSchedule, _Schedule, _TimeSlot
from pulser.sequence._seq_drawer import draw_sequence
from pulser.sequence._seq_drawer import Figure, draw_sequence
from pulser.sequence._seq_str import seq_to_str

if version_info[:2] >= (3, 8): # pragma: no cover
Expand Down Expand Up @@ -1328,8 +1328,7 @@ def draw(
"Can't draw the register for a sequence without a defined "
"register."
)
fig_reg, fig = draw_sequence(
self,
fig_reg, fig = self._plot(
draw_phase_area=draw_phase_area,
draw_interp_pts=draw_interp_pts,
draw_phase_shifts=draw_phase_shifts,
Expand All @@ -1338,14 +1337,17 @@ def draw(
draw_modulation="output" in mode,
draw_phase_curve=draw_phase_curve,
)
if fig_name is not None and draw_register:
if fig_name is not None and fig_reg is not None:
name, ext = os.path.splitext(fig_name)
fig.savefig(name + "_pulses" + ext, **kwargs_savefig)
fig_reg.savefig(name + "_register" + ext, **kwargs_savefig)
elif fig_name:
fig.savefig(fig_name, **kwargs_savefig)
plt.show()

def _plot(self, **draw_options: bool) -> tuple[Figure | None, Figure]:
return draw_sequence(self, **draw_options)

def _add(
self,
pulse: Union[Pulse, Parametrized],
Expand Down

0 comments on commit 2decdff

Please sign in to comment.