Skip to content

Commit

Permalink
templatize gate
Browse files Browse the repository at this point in the history
  • Loading branch information
gandalfr-KY committed Jan 17, 2025
1 parent 120de54 commit 0989281
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 117 deletions.
2 changes: 1 addition & 1 deletion exe/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <cstdint>
#include <iostream>
#include <scaluq/all.hpp>
#include <scaluq/gate/gate.hpp>

using namespace scaluq;
using namespace nlohmann;
Expand Down
213 changes: 107 additions & 106 deletions include/scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,64 +9,64 @@ namespace scaluq {
namespace internal {
// forward declarations

template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class GateBase;

template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class IGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class GlobalPhaseGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class XGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class YGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class ZGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class HGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SdagGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class TGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class TdagGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SqrtXGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SqrtXdagGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SqrtYGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SqrtYdagGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class P0GateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class P1GateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class RXGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class RYGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class RZGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class U1GateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class U2GateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class U3GateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SwapGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class PauliGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class PauliRotationGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class ProbablisticGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class SparseMatrixGateImpl;
template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
class DenseMatrixGateImpl;

} // namespace internal
Expand Down Expand Up @@ -103,83 +103,83 @@ enum class GateType {
Probablistic
};

template <typename T, std::floating_point S>
template <typename T, std::floating_point Fp, ExecutionSpace Sp>
constexpr GateType get_gate_type() {
using TWithoutConst = std::remove_cv_t<T>;
if constexpr (std::is_same_v<TWithoutConst, internal::GateBase<S>>)
if constexpr (std::is_same_v<TWithoutConst, internal::GateBase<Fp, Sp>>)
return GateType::Unknown;
else if constexpr (std::is_same_v<TWithoutConst, internal::IGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::IGateImpl<Fp, Sp>>)
return GateType::I;
else if constexpr (std::is_same_v<TWithoutConst, internal::GlobalPhaseGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::GlobalPhaseGateImpl<Fp, Sp>>)
return GateType::GlobalPhase;
else if constexpr (std::is_same_v<TWithoutConst, internal::XGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::XGateImpl<Fp, Sp>>)
return GateType::X;
else if constexpr (std::is_same_v<TWithoutConst, internal::YGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::YGateImpl<Fp, Sp>>)
return GateType::Y;
else if constexpr (std::is_same_v<TWithoutConst, internal::ZGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::ZGateImpl<Fp, Sp>>)
return GateType::Z;
else if constexpr (std::is_same_v<TWithoutConst, internal::HGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::HGateImpl<Fp, Sp>>)
return GateType::H;
else if constexpr (std::is_same_v<TWithoutConst, internal::SGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SGateImpl<Fp, Sp>>)
return GateType::S;
else if constexpr (std::is_same_v<TWithoutConst, internal::SdagGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SdagGateImpl<Fp, Sp>>)
return GateType::Sdag;
else if constexpr (std::is_same_v<TWithoutConst, internal::TGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::TGateImpl<Fp, Sp>>)
return GateType::T;
else if constexpr (std::is_same_v<TWithoutConst, internal::TdagGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::TdagGateImpl<Fp, Sp>>)
return GateType::Tdag;
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtXGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtXGateImpl<Fp, Sp>>)
return GateType::SqrtX;
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtXdagGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtXdagGateImpl<Fp, Sp>>)
return GateType::SqrtXdag;
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtYGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtYGateImpl<Fp, Sp>>)
return GateType::SqrtY;
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtYdagGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SqrtYdagGateImpl<Fp, Sp>>)
return GateType::SqrtYdag;
else if constexpr (std::is_same_v<TWithoutConst, internal::P0GateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::P0GateImpl<Fp, Sp>>)
return GateType::P0;
else if constexpr (std::is_same_v<TWithoutConst, internal::P1GateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::P1GateImpl<Fp, Sp>>)
return GateType::P1;
else if constexpr (std::is_same_v<TWithoutConst, internal::RXGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::RXGateImpl<Fp, Sp>>)
return GateType::RX;
else if constexpr (std::is_same_v<TWithoutConst, internal::RYGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::RYGateImpl<Fp, Sp>>)
return GateType::RY;
else if constexpr (std::is_same_v<TWithoutConst, internal::RZGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::RZGateImpl<Fp, Sp>>)
return GateType::RZ;
else if constexpr (std::is_same_v<TWithoutConst, internal::U1GateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::U1GateImpl<Fp, Sp>>)
return GateType::U1;
else if constexpr (std::is_same_v<TWithoutConst, internal::U2GateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::U2GateImpl<Fp, Sp>>)
return GateType::U2;
else if constexpr (std::is_same_v<TWithoutConst, internal::U3GateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::U3GateImpl<Fp, Sp>>)
return GateType::U3;
else if constexpr (std::is_same_v<TWithoutConst, internal::SwapGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SwapGateImpl<Fp, Sp>>)
return GateType::Swap;
else if constexpr (std::is_same_v<TWithoutConst, internal::PauliGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::PauliGateImpl<Fp, Sp>>)
return GateType::Pauli;
else if constexpr (std::is_same_v<TWithoutConst, internal::PauliRotationGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::PauliRotationGateImpl<Fp, Sp>>)
return GateType::PauliRotation;
else if constexpr (std::is_same_v<TWithoutConst, internal::SparseMatrixGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::SparseMatrixGateImpl<Fp, Sp>>)
return GateType::SparseMatrix;
else if constexpr (std::is_same_v<TWithoutConst, internal::DenseMatrixGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::DenseMatrixGateImpl<Fp, Sp>>)
return GateType::DenseMatrix;
else if constexpr (std::is_same_v<TWithoutConst, internal::ProbablisticGateImpl<S>>)
else if constexpr (std::is_same_v<TWithoutConst, internal::ProbablisticGateImpl<Fp, Sp>>)
return GateType::Probablistic;
else
static_assert(internal::lazy_false_v<T>, "unknown GateImpl");
}

namespace internal {
// GateBase テンプレートクラス
template <std::floating_point _FloatType>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType>> {
template <std::floating_point _FloatType, ExecutionSpace _SpaceType>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType, _SpaceType>> {
public:
using Fp = _FloatType;
using Sp = _SpaceType;

protected:
std::uint64_t _target_mask, _control_mask;

void check_qubit_mask_within_bounds(const StateVector<Fp>& state_vector) const;
void check_qubit_mask_within_bounds(const StateVectorBatched<Fp>& states) const;
void check_qubit_mask_within_bounds(const StateVector<Fp, Sp>& state_vector) const;
void check_qubit_mask_within_bounds(const StateVectorBatched<Fp, Sp>& states) const;

std::string get_qubit_info_as_string(const std::string& indent) const;

Expand All @@ -202,19 +202,19 @@ class GateBase : public std::enable_shared_from_this<GateBase<_FloatType>> {
return _target_mask | _control_mask;
}

[[nodiscard]] virtual std::shared_ptr<const GateBase<Fp>> get_inverse() const = 0;
[[nodiscard]] virtual std::shared_ptr<const GateBase<Fp, Sp>> get_inverse() const = 0;
[[nodiscard]] virtual internal::ComplexMatrix<Fp> get_matrix() const = 0;

virtual void update_quantum_state(StateVector<Fp>& state_vector) const = 0;
virtual void update_quantum_state(StateVectorBatched<Fp>& states) const = 0;
virtual void update_quantum_state(StateVector<Fp, Sp>& state_vector) const = 0;
virtual void update_quantum_state(StateVectorBatched<Fp, Sp>& states) const = 0;

[[nodiscard]] virtual std::string to_string(const std::string& indent = "") const = 0;

virtual void get_as_json(Json& j) const { j = Json{{"type", "Unknown"}}; }
};

template <typename T>
concept GateImpl = std::derived_from<T, GateBase<typename T::Fp>>;
concept GateImpl = std::derived_from<T, GateBase<typename T::Fp, typename T::Sp>>;

template <GateImpl T>
inline std::shared_ptr<const T> get_from_json(const Json&);
Expand All @@ -231,19 +231,20 @@ class GatePtr {

public:
using Fp = typename T::Fp;
GatePtr() : _gate_ptr(nullptr), _gate_type(get_gate_type<T, Fp>()) {}
using Sp = typename T::Sp;
GatePtr() : _gate_ptr(nullptr), _gate_type(get_gate_type<T, Fp, Sp>()) {}
template <GateImpl U>
GatePtr(const std::shared_ptr<const U>& gate_ptr) {
if constexpr (std::is_same_v<T, U>) {
_gate_type = get_gate_type<T, Fp>();
_gate_type = get_gate_type<T, Fp, Sp>();
_gate_ptr = gate_ptr;
} else if constexpr (std::is_same_v<T, internal::GateBase<Fp>>) {
} else if constexpr (std::is_same_v<T, internal::GateBase<Fp, Sp>>) {
// upcast
_gate_type = get_gate_type<U, Fp>();
_gate_type = get_gate_type<U, Fp, Sp>();
_gate_ptr = std::static_pointer_cast<const T>(gate_ptr);
} else {
// downcast
_gate_type = get_gate_type<T, Fp>();
_gate_type = get_gate_type<T, Fp, Sp>();
if (!(_gate_ptr = std::dynamic_pointer_cast<const T>(gate_ptr))) {
throw std::runtime_error("invalid gate cast");
}
Expand All @@ -254,13 +255,13 @@ class GatePtr {
if constexpr (std::is_same_v<T, U>) {
_gate_type = gate._gate_type;
_gate_ptr = gate._gate_ptr;
} else if constexpr (std::is_same_v<T, internal::GateBase<Fp>>) {
} else if constexpr (std::is_same_v<T, internal::GateBase<Fp, Sp>>) {
// upcast
_gate_type = gate._gate_type;
_gate_ptr = std::static_pointer_cast<const T>(gate._gate_ptr);
} else {
// downcast
if (gate._gate_type != get_gate_type<T, Fp>()) {
if (gate._gate_type != get_gate_type<T, Fp, Sp>()) {
throw std::runtime_error("invalid gate cast");
}
_gate_type = gate._gate_type;
Expand Down Expand Up @@ -288,38 +289,38 @@ class GatePtr {
std::string type = j.at("type");

// clang-format off
if (type == "I") gate = get_from_json<IGateImpl<Fp>>(j);
else if (type == "GlobalPhase") gate = get_from_json<GlobalPhaseGateImpl<Fp>>(j);
else if (type == "X") gate = get_from_json<XGateImpl<Fp>>(j);
else if (type == "Y") gate = get_from_json<YGateImpl<Fp>>(j);
else if (type == "Z") gate = get_from_json<ZGateImpl<Fp>>(j);
else if (type == "H") gate = get_from_json<HGateImpl<Fp>>(j);
else if (type == "S") gate = get_from_json<SGateImpl<Fp>>(j);
else if (type == "Sdag") gate = get_from_json<SdagGateImpl<Fp>>(j);
else if (type == "T") gate = get_from_json<TGateImpl<Fp>>(j);
else if (type == "Tdag") gate = get_from_json<TdagGateImpl<Fp>>(j);
else if (type == "SqrtX") gate = get_from_json<SqrtXGateImpl<Fp>>(j);
else if (type == "SqrtXdag") gate = get_from_json<SqrtXdagGateImpl<Fp>>(j);
else if (type == "SqrtY") gate = get_from_json<SqrtYGateImpl<Fp>>(j);
else if (type == "SqrtYdag") gate = get_from_json<SqrtYdagGateImpl<Fp>>(j);
else if (type == "RX") gate = get_from_json<RXGateImpl<Fp>>(j);
else if (type == "RY") gate = get_from_json<RYGateImpl<Fp>>(j);
else if (type == "RZ") gate = get_from_json<RZGateImpl<Fp>>(j);
else if (type == "U1") gate = get_from_json<U1GateImpl<Fp>>(j);
else if (type == "U2") gate = get_from_json<U2GateImpl<Fp>>(j);
else if (type == "U3") gate = get_from_json<U3GateImpl<Fp>>(j);
else if (type == "Swap") gate = get_from_json<SwapGateImpl<Fp>>(j);
else if (type == "Pauli") gate = get_from_json<PauliGateImpl<Fp>>(j);
else if (type == "PauliRotation") gate = get_from_json<PauliRotationGateImpl<Fp>>(j);
else if (type == "Probablistic") gate = get_from_json<ProbablisticGateImpl<Fp>>(j);
if (type == "I") gate = get_from_json<IGateImpl<Fp, Sp>>(j);
else if (type == "GlobalPhase") gate = get_from_json<GlobalPhaseGateImpl<Fp, Sp>>(j);
else if (type == "X") gate = get_from_json<XGateImpl<Fp, Sp>>(j);
else if (type == "Y") gate = get_from_json<YGateImpl<Fp, Sp>>(j);
else if (type == "Z") gate = get_from_json<ZGateImpl<Fp, Sp>>(j);
else if (type == "H") gate = get_from_json<HGateImpl<Fp, Sp>>(j);
else if (type == "S") gate = get_from_json<SGateImpl<Fp, Sp>>(j);
else if (type == "Sdag") gate = get_from_json<SdagGateImpl<Fp, Sp>>(j);
else if (type == "T") gate = get_from_json<TGateImpl<Fp, Sp>>(j);
else if (type == "Tdag") gate = get_from_json<TdagGateImpl<Fp, Sp>>(j);
else if (type == "SqrtX") gate = get_from_json<SqrtXGateImpl<Fp, Sp>>(j);
else if (type == "SqrtXdag") gate = get_from_json<SqrtXdagGateImpl<Fp, Sp>>(j);
else if (type == "SqrtY") gate = get_from_json<SqrtYGateImpl<Fp, Sp>>(j);
else if (type == "SqrtYdag") gate = get_from_json<SqrtYdagGateImpl<Fp, Sp>>(j);
else if (type == "RX") gate = get_from_json<RXGateImpl<Fp, Sp>>(j);
else if (type == "RY") gate = get_from_json<RYGateImpl<Fp, Sp>>(j);
else if (type == "RZ") gate = get_from_json<RZGateImpl<Fp, Sp>>(j);
else if (type == "U1") gate = get_from_json<U1GateImpl<Fp, Sp>>(j);
else if (type == "U2") gate = get_from_json<U2GateImpl<Fp, Sp>>(j);
else if (type == "U3") gate = get_from_json<U3GateImpl<Fp, Sp>>(j);
else if (type == "Swap") gate = get_from_json<SwapGateImpl<Fp, Sp>>(j);
else if (type == "Pauli") gate = get_from_json<PauliGateImpl<Fp, Sp>>(j);
else if (type == "PauliRotation") gate = get_from_json<PauliRotationGateImpl<Fp, Sp>>(j);
else if (type == "Probablistic") gate = get_from_json<ProbablisticGateImpl<Fp, Sp>>(j);
// clang-format on
}
};

} // namespace internal

template <std::floating_point Fp>
using Gate = internal::GatePtr<internal::GateBase<Fp>>;
template <std::floating_point Fp, ExecutionSpace Sp>
using Gate = internal::GatePtr<internal::GateBase<Fp, Sp>>;

#ifdef SCALUQ_USE_NANOBIND
namespace internal {
Expand Down Expand Up @@ -383,8 +384,8 @@ namespace internal {
}, \
"Read an object from the JSON representation of the gate.")

template <std::floating_point Fp>
nb::class_<Gate<Fp>> gate_base_def;
template <std::floating_point Fp, ExecutionSpace sp>
nb::class_<Gate<Fp, Sp>> gate_base_def;

#define DEF_GATE(GATE_TYPE, FLOAT, DESCRIPTION) \
::scaluq::internal::gate_base_def<FLOAT>.def(nb::init<GATE_TYPE<FLOAT>>(), \
Expand Down Expand Up @@ -427,14 +428,14 @@ void bind_gate_gate_hpp_without_precision(nb::module_& m) {
.value("DenseMatrix", GateType::DenseMatrix);
}

template <std::floating_point Fp>
template <std::floating_point Fp, ExecutionSpace sp>
void bind_gate_gate_hpp(nb::module_& m) {
gate_base_def<Fp> =
gate_base_def<Fp, Sp> =
DEF_GATE_BASE(Gate,
Fp,
"General class of QuantumGate.\n\n.. note:: Downcast to requred to use "
"gate-specific functions.")
.def(nb::init<Gate<Fp>>(), "Just copy shallowly.");
.def(nb::init<Gate<Fp, Sp>>(), "Just copy shallowly.");
}
} // namespace internal
#endif
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ target_sources(scaluq PRIVATE
# gate/gate_pauli.cpp
# gate/gate_probablistic.cpp
# gate/gate_standard.cpp
# gate/gate.cpp
gate/gate.cpp
# gate/param_gate.cpp
# gate/param_gate_pauli.cpp
# gate/param_gate_probablistic.cpp
Expand Down
Loading

0 comments on commit 0989281

Please sign in to comment.