Skip to content

Commit

Permalink
Merge pull request #1005 from qiboteam/baking
Browse files Browse the repository at this point in the history
Add baking in QM driver
  • Loading branch information
stavros11 authored Aug 27, 2024
2 parents e7a8d18 + 60580ce commit d91498e
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 59 deletions.
8 changes: 5 additions & 3 deletions src/qibolab/instruments/qm/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,11 @@ def register_waveforms(
else:
qmpulse = QmAcquisition.from_pulse(pulse, element)
waveforms = waveforms_from_pulse(pulse)
modes = ["I"] if dc else ["I", "Q"]
for mode in modes:
self.waveforms[getattr(qmpulse.waveforms, mode)] = waveforms[mode]
if dc:
self.waveforms[qmpulse.waveforms["single"]] = waveforms["I"]
else:
for mode in ["I", "Q"]:
self.waveforms[getattr(qmpulse.waveforms, mode)] = waveforms[mode]
return qmpulse

def register_iq_pulse(self, element: str, pulse: Pulse):
Expand Down
24 changes: 19 additions & 5 deletions src/qibolab/instruments/qm/config/pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def operation(pulse):
return str(hash(pulse))


def baked_duration(duration: int) -> int:
"""Calculate waveform length after pulse baking.
QM can only play pulses with length that is >16ns and multiple of
4ns. Waveforms that don't satisfy these constraints are padded with
zeros.
"""
return int(np.maximum((duration + 3) // 4 * 4, 16))


@dataclass(frozen=True)
class ConstantWaveform:
sample: float
Expand All @@ -47,9 +57,12 @@ class ArbitraryWaveform:
def from_pulse(cls, pulse: Pulse):
original_waveforms = pulse.envelopes(SAMPLING_RATE)
rotated_waveforms = rotate(original_waveforms, pulse.relative_phase)
new_duration = baked_duration(pulse.duration)
pad_len = new_duration - int(pulse.duration)
baked_waveforms = np.pad(rotated_waveforms, ((0, 0), (0, pad_len)))
return {
"I": cls(rotated_waveforms[0]),
"Q": cls(rotated_waveforms[1]),
"I": cls(baked_waveforms[0]),
"Q": cls(baked_waveforms[1]),
}


Expand All @@ -58,9 +71,10 @@ def from_pulse(cls, pulse: Pulse):

def waveforms_from_pulse(pulse: Pulse) -> Waveform:
"""Register QM waveforms for a given pulse."""
needs_baking = pulse.duration < 16 or pulse.duration % 4 != 0
wvtype = (
ConstantWaveform
if isinstance(pulse.envelope, Rectangular)
if isinstance(pulse.envelope, Rectangular) and not needs_baking
else ArbitraryWaveform
)
return wvtype.from_pulse(pulse)
Expand All @@ -87,15 +101,15 @@ class QmPulse:
def from_pulse(cls, pulse: Pulse):
op = operation(pulse)
return cls(
length=pulse.duration,
length=baked_duration(pulse.duration),
waveforms=Waveforms.from_op(op),
)

@classmethod
def from_dc_pulse(cls, pulse: Pulse):
op = operation(pulse)
return cls(
length=pulse.duration,
length=baked_duration(pulse.duration),
waveforms={"single": op},
)

Expand Down
30 changes: 5 additions & 25 deletions src/qibolab/instruments/qm/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,6 @@ def declare_octaves(octaves, host, calibration_path=None):
return config


def find_baking_pulses(sweepers):
"""Find pulses that require baking because we are sweeping their duration.
Args:
sweepers (list): List of :class:`qibolab.sweeper.Sweeper` objects.
"""
to_bake = set()
for sweeper in sweepers:
values = sweeper.values
step = values[1] - values[0] if len(values) > 0 else values[0]
if sweeper.parameter is Parameter.duration and step % 4 != 0:
for pulse in sweeper.pulses:
to_bake.add(pulse.id)

return to_bake


def fetch_results(result, acquisitions):
"""Fetches results from an executed experiment.
Expand Down Expand Up @@ -299,14 +282,6 @@ def configure_channels(
def register_pulse(self, channel: Channel, pulse: Pulse) -> str:
"""Add pulse in the QM ``config`` and return corresponding
operation."""
# if (
# pulse.duration % 4 != 0
# or pulse.duration < 16
# or pulse.id in pulses_to_bake
# ):
# qmpulse = BakedPulse(pulse, element)
# qmpulse.bake(self.config, durations=[pulse.duration])
# else:
name = str(channel.name)
if isinstance(channel, DcChannel):
return self.config.register_dc_pulse(name, pulse)
Expand All @@ -322,6 +297,11 @@ def register_pulses(self, configs: dict[str, Config], sequence: PulseSequence):
acquisitions (dict): Map from measurement instructions to acquisition objects.
"""
for channel_id, pulse in sequence:
if hasattr(pulse, "duration") and not pulse.duration.is_integer():
raise ValueError(
f"Quantum Machines cannot play pulse with duration {pulse.duration}. "
"Only integer duration in ns is supported."
)
if isinstance(pulse, Pulse):
channel = self.channels[str(channel_id)].logical_channel
self.register_pulse(channel, pulse)
Expand Down
1 change: 1 addition & 0 deletions src/qibolab/instruments/qm/program/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Parameters:
amplitude: Optional[_Variable] = None
phase: Optional[_Variable] = None
pulses: list[tuple[float, str]] = field(default_factory=list)
interpolated: bool = False


@dataclass
Expand Down
19 changes: 15 additions & 4 deletions src/qibolab/instruments/qm/program/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from qibolab.components import Config
from qibolab.execution_parameters import AcquisitionType, ExecutionParameters
from qibolab.identifier import ChannelType
from qibolab.pulses import Align, Delay, Pulse, VirtualZ
from qibolab.sweeper import ParallelSweepers

Expand All @@ -18,17 +19,24 @@
def _delay(pulse: Delay, element: str, parameters: Parameters):
# TODO: How to play delays on multiple elements?
if parameters.duration is None:
duration = int(pulse.duration) // 4
duration = max(int(pulse.duration) // 4 + 1, 4)
qua.wait(duration, element)
elif parameters.interpolated:
duration = parameters.duration + 1
qua.wait(duration, element)
else:
duration = parameters.duration
qua.wait(duration + 1, element)
duration = parameters.duration / 4
with qua.if_(duration < 4):
qua.wait(4, element)
with qua.else_():
qua.wait(duration, element)


def _play_multiple_waveforms(element: str, parameters: Parameters):
"""Sweeping pulse duration using distinctly uploaded waveforms."""
with qua.switch_(parameters.duration, unsafe=True):
for value, sweep_op in parameters.pulses:
with qua.case_(value // 4):
with qua.case_(value):
qua.play(sweep_op, element)


Expand Down Expand Up @@ -68,6 +76,9 @@ def play(args: ExecutionArguments):
processed_aligns = set()

for channel_id, pulse in args.sequence:
if channel_id.channel_type is ChannelType.ACQUISITION:
continue

element = str(channel_id)
op = operation(pulse)
params = args.parameters[op]
Expand Down
40 changes: 18 additions & 22 deletions src/qibolab/instruments/qm/program/sweepers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,6 @@ def check_max_offset(offset: Optional[float], max_offset: float = MAX_OFFSET):
)


# def _update_baked_pulses(sweeper, qmsequence, config):
# """Updates baked pulse if duration sweeper is used."""
# qmpulse = qmsequence.pulse_to_qmpulse[sweeper.pulses[0].id]
# is_baked = isinstance(qmpulse, BakedPulse)
# for pulse in sweeper.pulses:
# qmpulse = qmsequence.pulse_to_qmpulse[pulse.id]
# if isinstance(qmpulse, BakedPulse):
# if not is_baked:
# raise_error(
# TypeError,
# "Duration sweeper cannot contain both baked and not baked pulses.",
# )
# values = np.array(sweeper.values).astype(int)
# qmpulse.bake(config, values)


def _frequency(
channels: list[Channel],
values: npt.NDArray,
Expand Down Expand Up @@ -100,9 +84,6 @@ def _amplitude(
raise_error(ValueError, "Amplitude sweep values are >2 which is not supported.")

for pulse in pulses:
# if isinstance(instruction, Bake):
# instructions.update_kwargs(instruction, amplitude=a)
# else:
args.parameters[operation(pulse)].amplitude = qua.amp(variable)


Expand Down Expand Up @@ -145,18 +126,34 @@ def _duration(
configs: dict[str, Config],
args: ExecutionArguments,
):
# TODO: Handle baked pulses
for pulse in pulses:
args.parameters[operation(pulse)].duration = variable


def _duration_interpolated(
pulses: list[Pulse],
values: npt.NDArray,
variable: _Variable,
configs: dict[str, Config],
args: ExecutionArguments,
):
for pulse in pulses:
params = args.parameters[operation(pulse)]
params.duration = variable
params.interpolated = True


def normalize_phase(values):
"""Normalize phase from [0, 2pi] to [0, 1]."""
return values / (2 * np.pi)


def normalize_duration(values):
"""Convert duration from ns to clock cycles (clock cycle = 4ns)."""
if any(values < 16) and not all(values % 4 == 0):
raise ValueError(
"Cannot use interpolated duration sweeper for durations that are not multiple of 4ns or are less than 16ns. Please use normal duration sweeper."
)
return (values // 4).astype(int)


Expand All @@ -168,7 +165,6 @@ def normalize_duration(values):

NORMALIZERS = {
Parameter.relative_phase: normalize_phase,
Parameter.duration: normalize_duration,
Parameter.duration_interpolated: normalize_duration,
}
"""Functions to normalize sweeper values.
Expand All @@ -180,7 +176,7 @@ def normalize_duration(values):
Parameter.frequency: _frequency,
Parameter.amplitude: _amplitude,
Parameter.duration: _duration,
Parameter.duration_interpolated: _duration,
Parameter.duration_interpolated: _duration_interpolated,
Parameter.relative_phase: _relative_phase,
Parameter.bias: _bias,
}
Expand Down

0 comments on commit d91498e

Please sign in to comment.