Skip to content

Commit

Permalink
fix: Replace dump interface
Browse files Browse the repository at this point in the history
  • Loading branch information
alecandido committed Aug 12, 2024
1 parent abbe12c commit f6b8dcd
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 122 deletions.
176 changes: 65 additions & 111 deletions src/qibolab/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,7 @@
SingleQubitNatives,
TwoQubitNatives,
)
from qibolab.platform.platform import (
InstrumentMap,
Platform,
QubitMap,
QubitPairMap,
Settings,
update_configs,
)
from qibolab.platform.platform import Platform, Settings, update_configs
from qibolab.pulses import PulseSequence
from qibolab.pulses.pulse import PulseLike
from qibolab.qubits import Qubit, QubitId, QubitPair, QubitPairId
Expand Down Expand Up @@ -56,6 +49,32 @@ def load(cls, raw: dict):
pairs = _load_two_qubit_natives(raw["two_qubit"], qubits)
return cls(qubits, couplers, pairs)

def dump(self) -> dict:
"""Serialize native gates section to dictionary.
It follows the runcard format, using qubit and pair objects.
"""
native_gates = {
"single_qubit": {
dump_qubit_name(q): _dump_natives(qubit.native_gates)
for q, qubit in self.single_qubit.items()
}
}

native_gates["coupler"] = {
dump_qubit_name(q): _dump_natives(qubit.native_gates)
for q, qubit in self.coupler.items()
}

native_gates["two_qubit"] = {}
for pair in self.two_qubit.values():
natives = _dump_natives(pair.native_gates)
if len(natives) > 0:
pair_name = f"{pair.qubit1}-{pair.qubit2}"
native_gates["two_qubit"][pair_name] = natives

return native_gates


@dataclass
class Runcard:
Expand All @@ -72,6 +91,40 @@ def load(cls, path: Path):
natives = NativeGates.load(d["native_gates"])
return cls(settings=settings, components=components, native_gates=natives)

@classmethod
def from_platform(cls, platform: Platform):
return cls(
settings=platform.settings,
components=platform.configs,
native_gates=NativeGates(
single_qubit=platform.qubits,
coupler=platform.couplers,
two_qubit=platform.pairs,
),
)

def dump(self, path: Path, updates: Optional[list[ConfigUpdate]] = None):
"""Platform serialization as runcard (json) and kernels (npz).
The file saved follows the format explained in :ref:`Using runcards <using_runcards>`.
The requested ``path`` is the folder where the json and npz will be dumped.
``updates`` is an optional list if updates for platform configs. Later entries in the list take precedence over earlier ones (if they happen to update the same thing).
"""
_dump_kernels(self, path=path)

configs = self.components.copy()
update_configs(configs, updates or [])

settings = {
"settings": asdict(self.settings),
"components": _dump_component_configs(configs),
"native_gates": self.native_gates.dump(),
}

(path / RUNCARD).write_text(json.dumps(settings, sort_keys=False, indent=4))


def _load_qubit_name(name: str) -> QubitId:
"""Convert qubit name from string to integer or string."""
Expand Down Expand Up @@ -159,60 +212,7 @@ def _dump_natives(natives: Union[SingleQubitNatives, TwoQubitNatives]):
return data


def dump_native_gates(
qubits: QubitMap, pairs: QubitPairMap, couplers: Optional[QubitMap] = None
) -> dict:
"""Dump native gates section to dictionary following the runcard format,
using qubit and pair objects."""
# single-qubit native gates
native_gates = {
"single_qubit": {
dump_qubit_name(q): _dump_natives(qubit.native_gates)
for q, qubit in qubits.items()
}
}

# couplers native gates
native_gates["coupler"] = {
dump_qubit_name(q): _dump_natives(qubit.native_gates)
for q, qubit in qubits.items()
}

# two-qubit native gates
native_gates["two_qubit"] = {}
for pair in pairs.values():
natives = _dump_natives(pair.native_gates)
if len(natives) > 0:
pair_name = f"{pair.qubit1}-{pair.qubit2}"
native_gates["two_qubit"][pair_name] = natives

return native_gates


def dump_instruments(instruments: InstrumentMap) -> dict:
"""Dump instrument settings to a dictionary following the runcard
format."""
# Qblox modules settings are dictionaries and not dataclasses
data = {}
for name, instrument in instruments.items():
try:
# TODO: Migrate all instruments to this approach
# (I think it is also useful for qblox)
settings = instrument.dump()
if len(settings) > 0:
data[name] = settings
except AttributeError:
settings = instrument.settings
if settings is not None:
if isinstance(settings, dict):
data[name] = settings
else:
data[name] = settings.dump()

return data


def dump_component_configs(component_configs) -> dict:
def _dump_component_configs(component_configs) -> dict:
"""Dump channel configs."""
components = {}
for name, cfg in component_configs.items():
Expand All @@ -222,37 +222,7 @@ def dump_component_configs(component_configs) -> dict:
return components


def dump_runcard(
platform: Platform, path: Path, updates: Optional[list[ConfigUpdate]] = None
):
"""Serializes the platform and saves it as a json runcard file.
The file saved follows the format explained in :ref:`Using runcards <using_runcards>`.
Args:
platform (qibolab.platform.Platform): The platform to be serialized.
path (pathlib.Path): Path that the json file will be saved.
updates: List if updates for platform configs.
Later entries in the list take precedence over earlier ones (if they happen to update the same thing).
"""

configs = platform.configs.copy()
update_configs(configs, updates or [])

settings = {
"settings": asdict(platform.settings),
"qubits": list(platform.qubits),
"components": dump_component_configs(configs),
}

settings["native_gates"] = dump_native_gates(
platform.qubits, platform.pairs, platform.couplers
)

(path / RUNCARD).write_text(json.dumps(settings, sort_keys=False, indent=4))


def dump_kernels(platform: Platform, path: Path):
def _dump_kernels(runcard: Runcard, path: Path):
"""Creates Kernels instance from platform and dumps as npz.
Args:
Expand All @@ -262,27 +232,11 @@ def dump_kernels(platform: Platform, path: Path):

# create kernels
kernels = Kernels()
for qubit in platform.qubits.values():
kernel = platform.configs[qubit.acquisition.name].kernel
for qubit in runcard.native_gates.single_qubit.values():
kernel = runcard.components[qubit.acquisition.name].kernel
if kernel is not None:
kernels[qubit.name] = kernel

# dump only if not None
if len(kernels) > 0:
kernels.dump(path)


def dump_platform(
platform: Platform, path: Path, updates: Optional[list[ConfigUpdate]] = None
):
"""Platform serialization as runcard (json) and kernels (npz).
Args:
platform (qibolab.platform.Platform): The platform to be serialized.
path (pathlib.Path): Path where json and npz will be dumped.
updates: List if updates for platform configs.
Later entries in the list take precedence over earlier ones (if they happen to update the same thing).
"""

dump_kernels(platform=platform, path=path)
dump_runcard(platform=platform, path=path, updates=updates)
16 changes: 5 additions & 11 deletions tests/test_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,7 @@
from qibolab.platform.platform import update_configs
from qibolab.pulses import Delay, Gaussian, Pulse, PulseSequence, Rectangular
from qibolab.qubits import Qubit, QubitPair
from qibolab.serialize import (
PLATFORM,
Runcard,
dump_kernels,
dump_platform,
dump_runcard,
)
from qibolab.serialize import PLATFORM, Runcard, _dump_kernels

from .conftest import find_instrument

Expand Down Expand Up @@ -150,7 +144,7 @@ def test_update_configs(platform):


def test_dump_runcard(platform, tmp_path):
dump_runcard(platform, tmp_path)
Runcard.from_platform(platform).dump(tmp_path)
final = Runcard.load(tmp_path)
if platform.name == "dummy":
target = Runcard.load(FOLDER)
Expand All @@ -170,7 +164,7 @@ def test_dump_runcard_with_updates(platform, tmp_path):
qubit.drive.name: {"frequency": frequency},
qubit.acquisition.name: {"smearing": smearing},
}
dump_runcard(platform, tmp_path, [update])
Runcard.from_platform(platform).dump(tmp_path, [update])
final = Runcard.load(tmp_path)
assert final.components[qubit.drive.name]["frequency"] == frequency
assert final.components[qubit.acquisition.name]["smearing"] == smearing
Expand All @@ -186,7 +180,7 @@ def test_kernels(tmp_path, has_kernels):
if isinstance(config, AcquisitionConfig):
platform.configs[name] = replace(config, kernel=np.random.rand(10))

dump_kernels(platform, tmp_path)
_dump_kernels(Runcard.from_platform(platform), tmp_path)

if has_kernels:
kernels = Kernels.load(tmp_path)
Expand All @@ -208,7 +202,7 @@ def test_dump_platform(tmp_path, has_kernels):
if isinstance(config, AcquisitionConfig):
platform.configs[name] = replace(config, kernel=np.random.rand(10))

dump_platform(platform, tmp_path)
Runcard.from_platform(platform).dump(tmp_path)

settings = Runcard.load(tmp_path).settings
if has_kernels:
Expand Down

0 comments on commit f6b8dcd

Please sign in to comment.