From 736516d60ab292d4f0dfd51bc39b0c738d5717b5 Mon Sep 17 00:00:00 2001 From: gandalfr-KY Date: Fri, 4 Oct 2024 00:37:53 +0000 Subject: [PATCH] only float template --- exe/main.cpp | 65 ++++++++++----------------- scaluq/state/state_vector.hpp | 16 ++++--- scaluq/state/state_vector_batched.hpp | 19 ++++---- 3 files changed, 42 insertions(+), 58 deletions(-) diff --git a/exe/main.cpp b/exe/main.cpp index 5dac61be..d8af8bde 100644 --- a/exe/main.cpp +++ b/exe/main.cpp @@ -14,35 +14,23 @@ namespace internal { enum class GateType { Unknown, X }; -template class XGateImpl; -template -inline constexpr bool lazy_false_v = false; - template constexpr GateType get_gate_type() { - if constexpr (std::is_same_v> || - std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { + if constexpr (std::is_same_v) { return GateType::X; } else { static_assert(lazy_false_v, "unknown GateImpl"); } } -// GateBase テンプレートクラス -template -class GateBase : public std::enable_shared_from_this> { -public: - using FloatType = _FloatType; - using Space = _Space; - +class GateBase : public std::enable_shared_from_this { protected: std::uint64_t _target_mask, _control_mask; - void check_qubit_mask_within_bounds(const StateVector& state_vector) const { + template + void check_qubit_mask_within_bounds(const StateVector& state_vector) const { std::uint64_t full_mask = (1ULL << state_vector.n_qubits()) - 1; if ((_target_mask | _control_mask) > full_mask) [[unlikely]] { throw std::runtime_error( @@ -92,22 +80,17 @@ class GateBase : public std::enable_shared_from_this& state_vector) const = 0; + virtual void update_quantum_state(StateVector& state_vector) const = 0; + virtual void update_quantum_state(StateVector& state_vector) const = 0; [[nodiscard]] virtual std::string to_string(const std::string& indent = "") const = 0; }; template -concept GateImpl = std::derived_from>; +concept GateImpl = std::derived_from; template class GatePtr { - using FloatType = T::FloatType; - using Space = T::Space; - - static_assert(std::derived_from>, - "T must derive from GateBase"); - private: std::shared_ptr _gate_ptr; GateType _gate_type; @@ -166,13 +149,10 @@ class GatePtr { } }; -template -using Gate = GatePtr>; +using Gate = GatePtr; -template -void x_gate(std::uint64_t target_mask, - std::uint64_t control_mask, - StateVector& state) { +template +void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { Kokkos::parallel_for( state.dim() >> std::popcount(target_mask | control_mask), KOKKOS_LAMBDA(std::uint64_t it) { std::uint64_t i = @@ -182,12 +162,16 @@ void x_gate(std::uint64_t target_mask, Kokkos::fence(); } -template -class XGateImpl : public GateBase { +class XGateImpl : public GateBase { public: - using GateBase::GateBase; + using GateBase::GateBase; + + void update_quantum_state(StateVector& state_vector) const override { + this->check_qubit_mask_within_bounds(state_vector); + x_gate(this->_target_mask, this->_control_mask, state_vector); + } - void update_quantum_state(StateVector& state_vector) const override { + void update_quantum_state(StateVector& state_vector) const override { this->check_qubit_mask_within_bounds(state_vector); x_gate(this->_target_mask, this->_control_mask, state_vector); } @@ -202,7 +186,7 @@ class XGateImpl : public GateBase { class GateFactory { public: template - static internal::Gate create_gate(Args... args) { + static internal::Gate create_gate(Args... args) { return {std::make_shared(args...)}; } }; @@ -211,10 +195,9 @@ class GateFactory { namespace gate { -template -inline internal::Gate X(std::uint64_t target, - const std::vector& control_qubits = {}) { - return internal::GateFactory::create_gate>( +inline internal::Gate X(std::uint64_t target, + const std::vector& control_qubits = {}) { + return internal::GateFactory::create_gate( internal::vector_to_mask({target}), internal::vector_to_mask(control_qubits)); } @@ -226,9 +209,9 @@ int main() { Kokkos::initialize(); { std::uint64_t n_qubits = 3; - scaluq::StateVector state(n_qubits); + scaluq::StateVector state(n_qubits); state.load({0, 1, 2, 3, 4, 5, 6, 7}); - auto x_gate = scaluq::gate::X(1, {0, 2}); + auto x_gate = scaluq::gate::X(1, {0, 2}); x_gate->update_quantum_state(state); std::cout << state << std::endl; diff --git a/scaluq/state/state_vector.hpp b/scaluq/state/state_vector.hpp index 96de4482..a1bc87f4 100644 --- a/scaluq/state/state_vector.hpp +++ b/scaluq/state/state_vector.hpp @@ -15,21 +15,23 @@ namespace scaluq { using HostSpace = Kokkos::HostSpace; using DefaultSpace = Kokkos::DefaultExecutionSpace; -#define STATE_VECTOR_TEMPLATE(FloatType, Space) \ - template +// #define STATE_VECTOR_TEMPLATE(FloatType, Space) \ +// template -STATE_VECTOR_TEMPLATE(FloatType, Space) +#define STATE_VECTOR_TEMPLATE(FloatType) template + +template class StateVector { std::uint64_t _n_qubits; std::uint64_t _dim; using ComplexType = Kokkos::complex; - static_assert(std::is_same_v || std::is_same_v, - "Unsupported execution space tag"); + // static_assert(std::is_same_v || std::is_same_v, + // "Unsupported execution space tag"); public: static constexpr std::uint64_t UNMEASURED = 2; - Kokkos::View _raw; + Kokkos::View _raw; StateVector() = default; StateVector(std::uint64_t n_qubits) : _n_qubits(n_qubits), @@ -62,7 +64,7 @@ class StateVector { [[nodiscard]] static StateVector Haar_random_state( std::uint64_t n_qubits, std::uint64_t seed = std::random_device()()) { Kokkos::Random_XorShift64_Pool<> rand_pool(seed); - StateVector state(n_qubits); + StateVector state(n_qubits); Kokkos::parallel_for( state._dim, KOKKOS_LAMBDA(std::uint64_t i) { auto rand_gen = rand_pool.get_state(); diff --git a/scaluq/state/state_vector_batched.hpp b/scaluq/state/state_vector_batched.hpp index c29ebbd7..523f9406 100644 --- a/scaluq/state/state_vector_batched.hpp +++ b/scaluq/state/state_vector_batched.hpp @@ -5,15 +5,15 @@ namespace scaluq { -STATE_VECTOR_TEMPLATE(FloatType, Space) +STATE_VECTOR_TEMPLATE(FloatType) class StateVectorBatched { std::uint64_t _batch_size; std::uint64_t _n_qubits; std::uint64_t _dim; using ComplexType = Kokkos::complex; - static_assert(std::is_same_v || std::is_same_v, - "Unsupported execution space tag"); + // static_assert(std::is_same_v || std::is_same_v, + // "Unsupported execution space tag"); public: Kokkos::View _raw; @@ -36,7 +36,7 @@ class StateVectorBatched { [[nodiscard]] std::uint64_t batch_size() const { return this->_batch_size; } - void set_state_vector(const StateVector& state) { + void set_state_vector(const StateVector& state) { if (_raw.extent(1) != state._raw.extent(0)) [[unlikely]] { throw std::runtime_error( "Error: StateVectorBatched::set_state_vector(const StateVector&): Dimensions of " @@ -50,7 +50,7 @@ class StateVectorBatched { Kokkos::fence(); } - void set_state_vector_at(std::uint64_t batch_id, const StateVector& state) { + void set_state_vector_at(std::uint64_t batch_id, const StateVector& state) { if (_raw.extent(1) != state._raw.extent(0)) [[unlikely]] { throw std::runtime_error( "Error: StateVectorBatched::set_state_vector(std::uint64_t, const StateVector&): " @@ -61,8 +61,8 @@ class StateVectorBatched { Kokkos::fence(); } - [[nodiscard]] StateVector get_state_vector_at(std::uint64_t batch_id) const { - StateVector ret(_n_qubits); + [[nodiscard]] StateVector get_state_vector_at(std::uint64_t batch_id) const { + StateVector ret(_n_qubits); Kokkos::parallel_for( _dim, KOKKOS_CLASS_LAMBDA(std::uint64_t i) { ret._raw(i) = _raw(batch_id, i); }); Kokkos::fence(); @@ -154,8 +154,7 @@ class StateVectorBatched { Kokkos::Random_XorShift64_Pool<> rand_pool(seed); StateVectorBatched states(batch_size, n_qubits); if (set_same_state) { - states.set_state_vector( - StateVector::Haar_random_state(n_qubits, seed)); + states.set_state_vector(StateVector::Haar_random_state(n_qubits, seed)); } else { Kokkos::parallel_for( Kokkos::MDRangePolicy>({0, 0}, {states.batch_size(), states.dim()}), @@ -273,7 +272,7 @@ class StateVectorBatched { if (measured_value == 0 || measured_value == 1) { target_index.push_back(i); target_value.push_back(measured_value); - } else if (measured_value != StateVector::UNMEASURED) { + } else if (measured_value != StateVector::UNMEASURED) { throw std::runtime_error( "Error:StateVectorBatched::get_marginal_probability(const " "vector&): Invalid qubit state specified. Each qubit state must "