From 4780409ec679401f81fc14030bd6978d814547f3 Mon Sep 17 00:00:00 2001 From: Alessandro Candido Date: Mon, 12 Aug 2024 20:47:44 +0200 Subject: [PATCH] feat: Start adapting existing classes to automated serialization --- src/qibolab/native.py | 50 ++++++++++++++-------------------- src/qibolab/pulses/sequence.py | 16 ++++++++++- src/qibolab/qubits.py | 17 +++++++----- src/qibolab/serialize.py | 20 ++------------ 4 files changed, 48 insertions(+), 55 deletions(-) diff --git a/src/qibolab/native.py b/src/qibolab/native.py index 753c569108..ea41e9dd95 100644 --- a/src/qibolab/native.py +++ b/src/qibolab/native.py @@ -1,10 +1,9 @@ -from dataclasses import dataclass, field, fields -from typing import Optional +from typing import Annotated, Optional import numpy as np from .pulses import Drag, Gaussian, Pulse, PulseSequence -from .serialize_ import replace +from .serialize_ import Model, replace def _normalize_angles(theta, phi): @@ -15,7 +14,7 @@ def _normalize_angles(theta, phi): return theta, phi -class RxyFactory: +class RxyFactory(PulseSequence): """Factory for pulse sequences that generate single-qubit rotations around an axis in xy plane. @@ -27,17 +26,19 @@ class RxyFactory: sequence: The base sequence for the factory. """ - def __init__(self, sequence: PulseSequence): + @classmethod + def validate(cls, value): + sequence = PulseSequence(value) if len(sequence.channels) != 1: raise ValueError( f"Incompatible number of channels: {len(sequence.channels)}. " - f"{self.__class__} expects a sequence on exactly one channel." + f"{cls} expects a sequence on exactly one channel." ) if len(sequence) != 1: raise ValueError( f"Incompatible number of pulses: {len(sequence)}. " - f"{self.__class__} expects a sequence with exactly one pulse." + f"{cls} expects a sequence with exactly one pulse." ) pulse = sequence[0][1] @@ -46,10 +47,10 @@ def __init__(self, sequence: PulseSequence): if not isinstance(pulse.envelope, expected_envelopes): raise ValueError( f"Incompatible pulse envelope: {pulse.envelope.__class__}. " - f"{self.__class__} expects {expected_envelopes} envelope." + f"{cls} expects {expected_envelopes} envelope." ) - self._seq = sequence + return cls(value) def create_sequence(self, theta: float = np.pi, phi: float = 0.0) -> PulseSequence: """Create a sequence for single-qubit rotation. @@ -59,7 +60,7 @@ def create_sequence(self, theta: float = np.pi, phi: float = 0.0) -> PulseSequen phi: the angle that rotation axis forms with x axis. """ theta, phi = _normalize_angles(theta, phi) - ch, pulse = self._seq[0] + ch, pulse = self[0] assert isinstance(pulse, Pulse) new_amplitude = pulse.amplitude * theta / np.pi return PulseSequence( @@ -67,18 +68,14 @@ def create_sequence(self, theta: float = np.pi, phi: float = 0.0) -> PulseSequen ) -class FixedSequenceFactory: +class FixedSequenceFactory(PulseSequence): """Simple factory for a fixed arbitrary sequence.""" - def __init__(self, sequence: PulseSequence): - self._seq = sequence - def create_sequence(self) -> PulseSequence: - return self._seq.copy() + return self.copy() -@dataclass -class SingleQubitNatives: +class SingleQubitNatives(Model): """Container with the native single-qubit gates acting on a specific qubit.""" @@ -92,26 +89,19 @@ class SingleQubitNatives: """Pulse to activate coupler.""" -@dataclass -class TwoQubitNatives: +class TwoQubitNatives(Model): """Container with the native two-qubit gates acting on a specific pair of qubits.""" - CZ: Optional[FixedSequenceFactory] = field( - default=None, metadata={"symmetric": True} - ) - CNOT: Optional[FixedSequenceFactory] = field( - default=None, metadata={"symmetric": False} - ) - iSWAP: Optional[FixedSequenceFactory] = field( - default=None, metadata={"symmetric": True} - ) + CZ: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None + CNOT: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None + iSWAP: Annotated[Optional[FixedSequenceFactory], {"symmetric": True}] = None @property def symmetric(self): """Check if the defined two-qubit gates are symmetric between target and control qubits.""" return all( - fld.metadata["symmetric"] or getattr(self, fld.name) is None - for fld in fields(self) + info.metadata[0]["symmetric"] or getattr(self, fld) is None + for fld, info in self.model_fields.items() ) diff --git a/src/qibolab/pulses/sequence.py b/src/qibolab/pulses/sequence.py index 6ad5b4144a..331583e4ae 100644 --- a/src/qibolab/pulses/sequence.py +++ b/src/qibolab/pulses/sequence.py @@ -1,7 +1,10 @@ """PulseSequence class.""" from collections import UserList -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from typing import Any + +from pydantic_core import core_schema from qibolab.components import ChannelId @@ -22,6 +25,17 @@ class PulseSequence(UserList[_Element]): the action, and the channel on which it should be performed. """ + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: Callable[[Any], core_schema.CoreSchema] + ) -> core_schema.CoreSchema: + schema = handler(list[_Element]) + return core_schema.no_info_after_validator_function(cls.validate, schema) + + @classmethod + def validate(cls, value): + return cls(value) + @property def duration(self) -> float: """Duration of the entire sequence.""" diff --git a/src/qibolab/qubits.py b/src/qibolab/qubits.py index 645a1f3c1a..251a98b0c9 100644 --- a/src/qibolab/qubits.py +++ b/src/qibolab/qubits.py @@ -1,10 +1,13 @@ -from dataclasses import dataclass, field, fields -from typing import Optional, Union +from dataclasses import field, fields +from typing import Annotated, Optional, Union + +from pydantic import ConfigDict, Field from qibolab.components import AcquireChannel, DcChannel, IqChannel from qibolab.native import SingleQubitNatives, TwoQubitNatives +from qibolab.serialize_ import Model -QubitId = Union[str, int] +QubitId = Annotated[Union[str, int], Field(union_mode="left_to_right")] """Type for qubit names.""" CHANNEL_NAMES = ("probe", "acquisition", "drive", "drive12", "drive_cross", "flux") @@ -14,8 +17,7 @@ """ -@dataclass -class Qubit: +class Qubit(Model): """Representation of a physical qubit. Qubit objects are instantiated by :class:`qibolab.platforms.platform.Platform` @@ -31,6 +33,8 @@ class Qubit: send flux pulses to the qubit. """ + model_config = ConfigDict(frozen=False) + name: QubitId native_gates: SingleQubitNatives = field(default_factory=SingleQubitNatives) @@ -70,8 +74,7 @@ def mixer_frequencies(self): """Type for holding ``QubitPair``s in the ``platform.pairs`` dictionary.""" -@dataclass -class QubitPair: +class QubitPair(Model): """Data structure for holding the native two-qubit gates acting on a pair of qubits. diff --git a/src/qibolab/serialize.py b/src/qibolab/serialize.py index b85e4b8a62..55dbcc59eb 100644 --- a/src/qibolab/serialize.py +++ b/src/qibolab/serialize.py @@ -15,12 +15,7 @@ from qibolab.components import Config from qibolab.execution_parameters import ConfigUpdate, ExecutionParameters from qibolab.kernels import Kernels -from qibolab.native import ( - FixedSequenceFactory, - RxyFactory, - SingleQubitNatives, - TwoQubitNatives, -) +from qibolab.native import FixedSequenceFactory, SingleQubitNatives, TwoQubitNatives from qibolab.pulses import PulseSequence from qibolab.pulses.pulse import PulseLike from qibolab.qubits import Qubit, QubitId, QubitPair, QubitPairId @@ -180,17 +175,8 @@ def _load_single_qubit_natives(gates: dict) -> dict[QubitId, Qubit]: qubits = {} for q, gatedict in gates.items(): name = _load_qubit_name(q) - native_gates = SingleQubitNatives( - **{ - gate_name: ( - RxyFactory(_load_sequence(raw_sequence)) - if gate_name == "RX" - else FixedSequenceFactory(_load_sequence(raw_sequence)) - ) - for gate_name, raw_sequence in gatedict.items() - } - ) - qubits[name] = Qubit(_load_qubit_name(q), native_gates=native_gates) + native_gates = SingleQubitNatives(**gatedict) + qubits[name] = Qubit(name=name, native_gates=native_gates) return qubits