Skip to content

Commit

Permalink
feat: Start adapting existing classes to automated serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
alecandido committed Aug 12, 2024
1 parent 5efe7ba commit 4780409
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 55 deletions.
50 changes: 20 additions & 30 deletions src/qibolab/native.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass, field, fields
from typing import Optional
from typing import Annotated, Optional

import numpy as np

from .pulses import Drag, Gaussian, Pulse, PulseSequence
from .serialize_ import replace
from .serialize_ import Model, replace


def _normalize_angles(theta, phi):
Expand All @@ -15,7 +14,7 @@ def _normalize_angles(theta, phi):
return theta, phi


class RxyFactory:
class RxyFactory(PulseSequence):
"""Factory for pulse sequences that generate single-qubit rotations around
an axis in xy plane.
Expand All @@ -27,17 +26,19 @@ class RxyFactory:
sequence: The base sequence for the factory.
"""

def __init__(self, sequence: PulseSequence):
@classmethod
def validate(cls, value):
sequence = PulseSequence(value)
if len(sequence.channels) != 1:
raise ValueError(
f"Incompatible number of channels: {len(sequence.channels)}. "
f"{self.__class__} expects a sequence on exactly one channel."
f"{cls} expects a sequence on exactly one channel."
)

if len(sequence) != 1:
raise ValueError(
f"Incompatible number of pulses: {len(sequence)}. "
f"{self.__class__} expects a sequence with exactly one pulse."
f"{cls} expects a sequence with exactly one pulse."
)

pulse = sequence[0][1]
Expand All @@ -46,10 +47,10 @@ def __init__(self, sequence: PulseSequence):
if not isinstance(pulse.envelope, expected_envelopes):
raise ValueError(
f"Incompatible pulse envelope: {pulse.envelope.__class__}. "
f"{self.__class__} expects {expected_envelopes} envelope."
f"{cls} expects {expected_envelopes} envelope."
)

self._seq = sequence
return cls(value)

def create_sequence(self, theta: float = np.pi, phi: float = 0.0) -> PulseSequence:
"""Create a sequence for single-qubit rotation.
Expand All @@ -59,26 +60,22 @@ def create_sequence(self, theta: float = np.pi, phi: float = 0.0) -> PulseSequen
phi: the angle that rotation axis forms with x axis.
"""
theta, phi = _normalize_angles(theta, phi)
ch, pulse = self._seq[0]
ch, pulse = self[0]
assert isinstance(pulse, Pulse)
new_amplitude = pulse.amplitude * theta / np.pi
return PulseSequence(
[(ch, replace(pulse, amplitude=new_amplitude, relative_phase=phi))]
)


class FixedSequenceFactory:
class FixedSequenceFactory(PulseSequence):
"""Simple factory for a fixed arbitrary sequence."""

def __init__(self, sequence: PulseSequence):
self._seq = sequence

def create_sequence(self) -> PulseSequence:
return self._seq.copy()
return self.copy()


@dataclass
class SingleQubitNatives:
class SingleQubitNatives(Model):
"""Container with the native single-qubit gates acting on a specific
qubit."""

Expand All @@ -92,26 +89,19 @@ class SingleQubitNatives:
"""Pulse to activate coupler."""


@dataclass
class TwoQubitNatives:
class TwoQubitNatives(Model):
"""Container with the native two-qubit gates acting on a specific pair of
qubits."""

CZ: Optional[FixedSequenceFactory] = field(
default=None, metadata={"symmetric": True}
)
CNOT: Optional[FixedSequenceFactory] = field(
default=None, metadata={"symmetric": False}
)
iSWAP: Optional[FixedSequenceFactory] = field(
default=None, metadata={"symmetric": True}
)
CZ: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None
CNOT: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None
iSWAP: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None

@property
def symmetric(self):
"""Check if the defined two-qubit gates are symmetric between target
and control qubits."""
return all(
fld.metadata["symmetric"] or getattr(self, fld.name) is None
for fld in fields(self)
info.metadata[0]["symmetric"] or getattr(self, fld) is None
for fld, info in self.model_fields.items()
)
16 changes: 15 additions & 1 deletion src/qibolab/pulses/sequence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""PulseSequence class."""

from collections import UserList
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from typing import Any

from pydantic_core import core_schema

from qibolab.components import ChannelId

Expand All @@ -22,6 +25,17 @@ class PulseSequence(UserList[_Element]):
the action, and the channel on which it should be performed.
"""

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(list[_Element])
return core_schema.no_info_after_validator_function(cls.validate, schema)

@classmethod
def validate(cls, value):
return cls(value)

@property
def duration(self) -> float:
"""Duration of the entire sequence."""
Expand Down
17 changes: 10 additions & 7 deletions src/qibolab/qubits.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from dataclasses import dataclass, field, fields
from typing import Optional, Union
from dataclasses import field, fields
from typing import Annotated, Optional, Union

from pydantic import ConfigDict, Field

from qibolab.components import AcquireChannel, DcChannel, IqChannel
from qibolab.native import SingleQubitNatives, TwoQubitNatives
from qibolab.serialize_ import Model

QubitId = Union[str, int]
QubitId = Annotated[Union[str, int], Field(union_mode="left_to_right")]
"""Type for qubit names."""

CHANNEL_NAMES = ("probe", "acquisition", "drive", "drive12", "drive_cross", "flux")
Expand All @@ -14,8 +17,7 @@
"""


@dataclass
class Qubit:
class Qubit(Model):
"""Representation of a physical qubit.
Qubit objects are instantiated by :class:`qibolab.platforms.platform.Platform`
Expand All @@ -31,6 +33,8 @@ class Qubit:
send flux pulses to the qubit.
"""

model_config = ConfigDict(frozen=False)

name: QubitId

native_gates: SingleQubitNatives = field(default_factory=SingleQubitNatives)
Expand Down Expand Up @@ -70,8 +74,7 @@ def mixer_frequencies(self):
"""Type for holding ``QubitPair``s in the ``platform.pairs`` dictionary."""


@dataclass
class QubitPair:
class QubitPair(Model):
"""Data structure for holding the native two-qubit gates acting on a pair
of qubits.
Expand Down
20 changes: 3 additions & 17 deletions src/qibolab/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
from qibolab.components import Config
from qibolab.execution_parameters import ConfigUpdate, ExecutionParameters
from qibolab.kernels import Kernels
from qibolab.native import (
FixedSequenceFactory,
RxyFactory,
SingleQubitNatives,
TwoQubitNatives,
)
from qibolab.native import FixedSequenceFactory, SingleQubitNatives, TwoQubitNatives
from qibolab.pulses import PulseSequence
from qibolab.pulses.pulse import PulseLike
from qibolab.qubits import Qubit, QubitId, QubitPair, QubitPairId
Expand Down Expand Up @@ -180,17 +175,8 @@ def _load_single_qubit_natives(gates: dict) -> dict[QubitId, Qubit]:
qubits = {}
for q, gatedict in gates.items():
name = _load_qubit_name(q)
native_gates = SingleQubitNatives(
**{
gate_name: (
RxyFactory(_load_sequence(raw_sequence))
if gate_name == "RX"
else FixedSequenceFactory(_load_sequence(raw_sequence))
)
for gate_name, raw_sequence in gatedict.items()
}
)
qubits[name] = Qubit(_load_qubit_name(q), native_gates=native_gates)
native_gates = SingleQubitNatives(**gatedict)
qubits[name] = Qubit(name=name, native_gates=native_gates)
return qubits


Expand Down

0 comments on commit 4780409

Please sign in to comment.