From 6d73cccfa726c7ca91c98855d7068451f201d23c Mon Sep 17 00:00:00 2001 From: Glacialte Date: Wed, 18 Sep 2024 03:39:13 +0000 Subject: [PATCH] add batch update --- scaluq/gate/gate.hpp | 13 +- scaluq/gate/gate_matrix.hpp | 10 +- scaluq/gate/gate_pauli.hpp | 11 +- scaluq/gate/gate_probablistic.hpp | 33 +++- scaluq/gate/gate_standard.hpp | 114 ++++++++++--- scaluq/gate/param_gate_pauli.hpp | 7 +- scaluq/gate/param_gate_probablistic.hpp | 38 ++++- scaluq/gate/param_gate_standard.hpp | 15 +- scaluq/gate/update_ops_dense_matrix.cpp | 22 +++ scaluq/gate/update_ops_standard.cpp | 203 ++++++++++++++++++++++++ scaluq/operator/apply_pauli.cpp | 99 ++++++++++++ scaluq/operator/apply_pauli.hpp | 12 ++ 12 files changed, 530 insertions(+), 47 deletions(-) diff --git a/scaluq/gate/gate.hpp b/scaluq/gate/gate.hpp index c2c3177d..d3cbbef1 100644 --- a/scaluq/gate/gate.hpp +++ b/scaluq/gate/gate.hpp @@ -158,7 +158,18 @@ class GateBase : public std::enable_shared_from_this { } } - std::string get_qubit_info_as_string(const std::string& indent) const { + [[nodiscard]] std::vector mask_to_vector(std::uint64_t mask) const { + std::vector qubits; + for (std::uint64_t i = 0; i < 64; ++i) { + if ((mask >> i) & 1) qubits.push_back(i); + } + return qubits; + } + + [[nodiscard]] + + std::string + get_qubit_info_as_string(const std::string& indent) const { std::ostringstream ss; auto targets = target_qubit_list(); auto controls = control_qubit_list(); diff --git a/scaluq/gate/gate_matrix.hpp b/scaluq/gate/gate_matrix.hpp index be9a54ee..dfe02065 100644 --- a/scaluq/gate/gate_matrix.hpp +++ b/scaluq/gate/gate_matrix.hpp @@ -48,7 +48,10 @@ class OneTargetMatrixGateImpl : public GateBase { one_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + one_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -107,7 +110,10 @@ class TwoTargetMatrixGateImpl : public GateBase { two_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + two_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/gate_pauli.hpp b/scaluq/gate/gate_pauli.hpp index adcd3496..a8cc5d19 100644 --- a/scaluq/gate/gate_pauli.hpp +++ b/scaluq/gate/gate_pauli.hpp @@ -27,7 +27,10 @@ class PauliGateImpl : public GateBase { apply_pauli(_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + auto [bit_flip_mask, phase_flip_mask] = _pauli.get_XZ_mask_representation(); + apply_pauli(_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -75,7 +78,11 @@ class PauliRotationGateImpl : public GateBase { _control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), _angle, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + auto [bit_flip_mask, phase_flip_mask] = _pauli.get_XZ_mask_representation(); + apply_pauli_rotation( + _control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), _angle, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/gate_probablistic.hpp b/scaluq/gate/gate_probablistic.hpp index 5c6fa0ad..2792e079 100644 --- a/scaluq/gate/gate_probablistic.hpp +++ b/scaluq/gate/gate_probablistic.hpp @@ -7,7 +7,7 @@ namespace scaluq { namespace internal { class ProbablisticGateImpl : public GateBase { std::vector _distribution; - std::vector _cumlative_distribution; + std::vector _cumulative_distribution; std::vector _gate_list; public: @@ -21,10 +21,10 @@ class ProbablisticGateImpl : public GateBase { if (n != gate_list.size()) { throw std::runtime_error("distribution and gate_list have different size."); } - _cumlative_distribution.resize(n + 1); + _cumulative_distribution.resize(n + 1); std::partial_sum( - distribution.begin(), distribution.end(), _cumlative_distribution.begin() + 1); - if (std::abs(_cumlative_distribution.back() - 1.) > 1e-6) { + distribution.begin(), distribution.end(), _cumulative_distribution.begin() + 1); + if (std::abs(_cumulative_distribution.back() - 1.) > 1e-6) { throw std::runtime_error("Sum of distribution must be equal to 1."); } } @@ -79,14 +79,33 @@ class ProbablisticGateImpl : public GateBase { void update_quantum_state(StateVector& state_vector) const override { Random random; double r = random.uniform(); - std::uint64_t i = std::distance(_cumlative_distribution.begin(), - std::ranges::upper_bound(_cumlative_distribution, r)) - + std::uint64_t i = std::distance(_cumulative_distribution.begin(), + std::ranges::upper_bound(_cumulative_distribution, r)) - 1; if (i >= _gate_list.size()) i = _gate_list.size() - 1; _gate_list[i]->update_quantum_state(state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + // これでいいのか?わからない... + void update_quantum_state(StateVectorBatched& states) const override { + Random random; + std::vector r(states.batch_size()); + std::ranges::generate(r, [&random]() { return random.uniform(); }); + std::vector indicies(states.batch_size()); + std::ranges::transform(r, indicies.begin(), [this](double r) { + return std::distance(_cumulative_distribution.begin(), + std::ranges::upper_bound(_cumulative_distribution, r)) - + 1; + }); + std::ranges::transform(indicies, indicies.begin(), [this](std::uint64_t i) { + if (i >= _gate_list.size()) i = _gate_list.size() - 1; + return i; + }); + for (std::size_t i = 0; i < states.batch_size(); ++i) { + auto state_vector = states.get_state_vector_at(i); + _gate_list[indicies[i]]->update_quantum_state(state_vector); + } + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/gate_standard.hpp b/scaluq/gate/gate_standard.hpp index b2914fb8..abf4d64e 100644 --- a/scaluq/gate/gate_standard.hpp +++ b/scaluq/gate/gate_standard.hpp @@ -19,7 +19,9 @@ class IGateImpl : public GateBase { i_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + i_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -51,7 +53,10 @@ class GlobalPhaseGateImpl : public GateBase { global_phase_gate(_target_mask, _control_mask, _phase, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + global_phase_gate(_target_mask, _control_mask, _phase, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -89,7 +94,10 @@ class XGateImpl : public GateBase { x_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + x_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -115,7 +123,10 @@ class YGateImpl : public GateBase { y_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + y_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -141,7 +152,10 @@ class ZGateImpl : public GateBase { z_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + z_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -168,7 +182,10 @@ class HGateImpl : public GateBase { h_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + h_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -203,7 +220,10 @@ class SGateImpl : public GateBase { s_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + s_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -231,7 +251,10 @@ class SdagGateImpl : public GateBase { sdag_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + sdag_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -261,7 +284,10 @@ class TGateImpl : public GateBase { t_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + t_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -289,7 +315,10 @@ class TdagGateImpl : public GateBase { tdag_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + tdag_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -319,7 +348,10 @@ class SqrtXGateImpl : public GateBase { sqrtx_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + sqrtx_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -347,7 +379,10 @@ class SqrtXdagGateImpl : public GateBase { sqrtxdag_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + sqrtxdag_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -377,7 +412,10 @@ class SqrtYGateImpl : public GateBase { sqrty_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + sqrty_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -405,7 +443,10 @@ class SqrtYdagGateImpl : public GateBase { sqrtydag_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + sqrtydag_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -437,7 +478,10 @@ class P0GateImpl : public GateBase { p0_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + p0_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -465,7 +509,10 @@ class P1GateImpl : public GateBase { p1_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + p1_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -494,7 +541,10 @@ class RXGateImpl : public RotationGateBase { rx_gate(_target_mask, _control_mask, _angle, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + rx_gate(_target_mask, _control_mask, _angle, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -524,7 +574,10 @@ class RYGateImpl : public RotationGateBase { ry_gate(_target_mask, _control_mask, _angle, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + ry_gate(_target_mask, _control_mask, _angle, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -553,7 +606,10 @@ class RZGateImpl : public RotationGateBase { rz_gate(_target_mask, _control_mask, _angle, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + rz_gate(_target_mask, _control_mask, _angle, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -587,7 +643,10 @@ class U1GateImpl : public GateBase { u1_gate(_target_mask, _control_mask, _lambda, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + u1_gate(_target_mask, _control_mask, _lambda, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -626,7 +685,10 @@ class U2GateImpl : public GateBase { u2_gate(_target_mask, _control_mask, _phi, _lambda, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + u2_gate(_target_mask, _control_mask, _phi, _lambda, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -668,7 +730,10 @@ class U3GateImpl : public GateBase { u3_gate(_target_mask, _control_mask, _theta, _phi, _lambda, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + u3_gate(_target_mask, _control_mask, _theta, _phi, _lambda, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -694,7 +759,10 @@ class SwapGateImpl : public GateBase { swap_gate(_target_mask, _control_mask, state_vector); } - void update_quantum_state(StateVectorBatched& states) const override {} + void update_quantum_state(StateVectorBatched& states) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + swap_gate(_target_mask, _control_mask, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/param_gate_pauli.hpp b/scaluq/gate/param_gate_pauli.hpp index 8d77491d..0557f48c 100644 --- a/scaluq/gate/param_gate_pauli.hpp +++ b/scaluq/gate/param_gate_pauli.hpp @@ -44,8 +44,11 @@ class ParamPauliRotationGateImpl : public ParamGateBase { _pcoef * param, state_vector); } - - void update_quantum_state(StateVectorBatched& states, double param) const override {} + void update_quantum_state(StateVectorBatched& states, double param) const override { + auto [bit_flip_mask, phase_flip_mask] = _pauli.get_XZ_mask_representation(); + apply_pauli_rotation( + _control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), _pcoef * param, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/param_gate_probablistic.hpp b/scaluq/gate/param_gate_probablistic.hpp index 46ef1186..5947e37d 100644 --- a/scaluq/gate/param_gate_probablistic.hpp +++ b/scaluq/gate/param_gate_probablistic.hpp @@ -11,7 +11,7 @@ namespace internal { class ParamProbablisticGateImpl : public ParamGateBase { using EitherGate = std::variant; std::vector _distribution; - std::vector _cumlative_distribution; + std::vector _cumulative_distribution; std::vector _gate_list; public: @@ -25,10 +25,10 @@ class ParamProbablisticGateImpl : public ParamGateBase { if (n != gate_list.size()) { throw std::runtime_error("distribution and gate_list have different size."); } - _cumlative_distribution.resize(n + 1); + _cumulative_distribution.resize(n + 1); std::partial_sum( - distribution.begin(), distribution.end(), _cumlative_distribution.begin() + 1); - if (std::abs(_cumlative_distribution.back() - 1.) > 1e-6) { + distribution.begin(), distribution.end(), _cumulative_distribution.begin() + 1); + if (std::abs(_cumulative_distribution.back() - 1.) > 1e-6) { throw std::runtime_error("Sum of distribution must be equal to 1."); } } @@ -84,8 +84,8 @@ class ParamProbablisticGateImpl : public ParamGateBase { void update_quantum_state(StateVector& state_vector, double param) const override { Random random; double r = random.uniform(); - std::uint64_t i = std::distance(_cumlative_distribution.begin(), - std::ranges::upper_bound(_cumlative_distribution, r)) - + std::uint64_t i = std::distance(_cumulative_distribution.begin(), + std::ranges::upper_bound(_cumulative_distribution, r)) - 1; if (i >= _gate_list.size()) i = _gate_list.size() - 1; const auto& gate = _gate_list[i]; @@ -96,7 +96,31 @@ class ParamProbablisticGateImpl : public ParamGateBase { } } - void update_quantum_state(StateVectorBatched& states, double param) const override {} + // これでいいのか?分からない... + void update_quantum_state(StateVectorBatched& states, double param) const override { + Random random; + std::vector r(states.batch_size()); + std::ranges::generate(r, [&random]() { return random.uniform(); }); + std::vector indicies(states.batch_size()); + std::ranges::transform(r, indicies.begin(), [this](double r) { + return std::distance(_cumulative_distribution.begin(), + std::ranges::upper_bound(_cumulative_distribution, r)) - + 1; + }); + std::ranges::transform(indicies, indicies.begin(), [this](std::uint64_t i) { + if (i >= _gate_list.size()) i = _gate_list.size() - 1; + return i; + }); + for (std::size_t i = 0; i < states.batch_size(); ++i) { + const auto& gate = _gate_list[indicies[i]]; + auto state_vector = states.get_state_vector_at(i); + if (gate.index() == 0) { + std::get<0>(gate)->update_quantum_state(state_vector); + } else { + std::get<1>(gate)->update_quantum_state(state_vector, param); + } + } + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/param_gate_standard.hpp b/scaluq/gate/param_gate_standard.hpp index d3094a16..7e1f6470 100644 --- a/scaluq/gate/param_gate_standard.hpp +++ b/scaluq/gate/param_gate_standard.hpp @@ -28,7 +28,10 @@ class ParamRXGateImpl : public ParamGateBase { rx_gate(_target_mask, _control_mask, _pcoef * param, state_vector); } - void update_quantum_state(StateVectorBatched& states, double param) const override {} + void update_quantum_state(StateVectorBatched& states, double param) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + rx_gate(_target_mask, _control_mask, _pcoef * param, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -57,7 +60,10 @@ class ParamRYGateImpl : public ParamGateBase { ry_gate(_target_mask, _control_mask, _pcoef * param, state_vector); } - void update_quantum_state(StateVectorBatched& states, double param) const override {} + void update_quantum_state(StateVectorBatched& states, double param) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + ry_gate(_target_mask, _control_mask, _pcoef * param, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; @@ -86,7 +92,10 @@ class ParamRZGateImpl : public ParamGateBase { rz_gate(_target_mask, _control_mask, _pcoef * param, state_vector); } - void update_quantum_state(StateVectorBatched& states, double param) const override {} + void update_quantum_state(StateVectorBatched& states, double param) const override { + check_qubit_mask_within_bounds(states.get_state_vector_at(0)); + rz_gate(_target_mask, _control_mask, _pcoef * param, states); + } std::string to_string(const std::string& indent) const override { std::ostringstream ss; diff --git a/scaluq/gate/update_ops_dense_matrix.cpp b/scaluq/gate/update_ops_dense_matrix.cpp index ca7141eb..d48bf06b 100644 --- a/scaluq/gate/update_ops_dense_matrix.cpp +++ b/scaluq/gate/update_ops_dense_matrix.cpp @@ -26,6 +26,28 @@ void one_target_dense_matrix_gate(std::uint64_t target_mask, Kokkos::fence(); } +void one_target_dense_matrix_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + const Matrix2x2& matrix, + StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t basis_0 = + insert_zero_at_mask_positions(it, control_mask | target_mask) | control_mask; + std::uint64_t basis_1 = basis_0 | target_mask; + Complex val0 = states._raw(batch_id, basis_0); + Complex val1 = states._raw(batch_id, basis_1); + Complex res0 = matrix[0][0] * val0 + matrix[0][1] * val1; + Complex res1 = matrix[1][0] * val0 + matrix[1][1] * val1; + states._raw(batch_id, basis_0) = res0; + states._raw(batch_id, basis_1) = res1; + }); + Kokkos::fence(); +} + void two_target_dense_matrix_gate(std::uint64_t target_mask, std::uint64_t control_mask, const Matrix4x4& matrix, diff --git a/scaluq/gate/update_ops_standard.cpp b/scaluq/gate/update_ops_standard.cpp index 1e64854b..98bc4a2d 100644 --- a/scaluq/gate/update_ops_standard.cpp +++ b/scaluq/gate/update_ops_standard.cpp @@ -9,6 +9,7 @@ namespace scaluq { namespace internal { void i_gate(std::uint64_t, std::uint64_t, StateVector&) {} +void i_gate(std::uint64_t, std::uint64_t, StateVectorBatched&) {} void global_phase_gate(std::uint64_t, std::uint64_t control_mask, @@ -22,6 +23,21 @@ void global_phase_gate(std::uint64_t, Kokkos::fence(); } +void global_phase_gate(std::uint64_t, + std::uint64_t control_mask, + double phase, + StateVectorBatched& states) { + Complex coef = Kokkos::polar(1., phase); + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, {states.batch_size(), states.dim() >> std::popcount(control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t i) { + states._raw(batch_id, insert_zero_at_mask_positions(i, control_mask) | control_mask) *= + coef; + }); + Kokkos::fence(); +} + void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { Kokkos::parallel_for( state.dim() >> std::popcount(target_mask | control_mask), KOKKOS_LAMBDA(std::uint64_t it) { @@ -31,6 +47,20 @@ void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& }); Kokkos::fence(); } +void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t i = + insert_zero_at_mask_positions(it, control_mask | target_mask) | control_mask; + Kokkos::Experimental::swap(states._raw(batch_id, i), + states._raw(batch_id, i | target_mask)); + }); + Kokkos::fence(); +} + void y_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { Kokkos::parallel_for( state.dim() >> std::popcount(target_mask | control_mask), KOKKOS_LAMBDA(std::uint64_t it) { @@ -42,6 +72,21 @@ void y_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& }); Kokkos::fence(); } +void y_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t i = + insert_zero_at_mask_positions(it, control_mask | target_mask) | control_mask; + states._raw(batch_id, i) *= Complex(0, 1); + states._raw(batch_id, i | target_mask) *= Complex(0, -1); + Kokkos::Experimental::swap(states._raw(batch_id, i), + states._raw(batch_id, i | target_mask)); + }); + Kokkos::fence(); +} void z_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { Kokkos::parallel_for( @@ -52,10 +97,25 @@ void z_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& }); Kokkos::fence(); } +void z_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t i = + insert_zero_at_mask_positions(it, control_mask | target_mask) | control_mask; + states._raw(batch_id, i | target_mask) *= Complex(-1, 0); + }); + Kokkos::fence(); +} void h_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, HADAMARD_MATRIX(), state); } +void h_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, HADAMARD_MATRIX(), states); +} void one_target_phase_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -69,48 +129,99 @@ void one_target_phase_gate(std::uint64_t target_mask, }); Kokkos::fence(); } +void one_target_phase_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + Complex phase, + StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t i = + insert_zero_at_mask_positions(it, control_mask | target_mask) | control_mask; + states._raw(batch_id, i | target_mask) *= phase; + }); + Kokkos::fence(); +} void s_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_phase_gate(target_mask, control_mask, Complex(0, 1), state); } +void s_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_phase_gate(target_mask, control_mask, Complex(0, 1), states); +} void sdag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_phase_gate(target_mask, control_mask, Complex(0, -1), state); } +void sdag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_phase_gate(target_mask, control_mask, Complex(0, -1), states); +} void t_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_phase_gate( target_mask, control_mask, Complex(INVERSE_SQRT2(), INVERSE_SQRT2()), state); } +void t_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_phase_gate( + target_mask, control_mask, Complex(INVERSE_SQRT2(), INVERSE_SQRT2()), states); +} void tdag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_phase_gate( target_mask, control_mask, Complex(INVERSE_SQRT2(), -INVERSE_SQRT2()), state); } +void tdag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_phase_gate( + target_mask, control_mask, Complex(INVERSE_SQRT2(), -INVERSE_SQRT2()), states); +} void sqrtx_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, SQRT_X_GATE_MATRIX(), state); } +void sqrtx_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, SQRT_X_GATE_MATRIX(), states); +} void sqrtxdag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, SQRT_X_DAG_GATE_MATRIX(), state); } +void sqrtxdag_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, SQRT_X_DAG_GATE_MATRIX(), states); +} void sqrty_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, SQRT_Y_GATE_MATRIX(), state); } +void sqrty_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, SQRT_Y_GATE_MATRIX(), states); +} void sqrtydag_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, SQRT_Y_DAG_GATE_MATRIX(), state); } +void sqrtydag_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, SQRT_Y_DAG_GATE_MATRIX(), states); +} void p0_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, PROJ_0_MATRIX(), state); } +void p0_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, PROJ_0_MATRIX(), states); +} void p1_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { one_target_dense_matrix_gate(target_mask, control_mask, PROJ_1_MATRIX(), state); } +void p1_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + one_target_dense_matrix_gate(target_mask, control_mask, PROJ_1_MATRIX(), states); +} void rx_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -121,6 +232,15 @@ void rx_gate(std::uint64_t target_mask, Matrix2x2 matrix = {cosval, Complex(0, -sinval), Complex(0, -sinval), cosval}; one_target_dense_matrix_gate(target_mask, control_mask, matrix, state); } +void rx_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double angle, + StateVectorBatched& states) { + const double cosval = std::cos(angle / 2.); + const double sinval = std::sin(angle / 2.); + Matrix2x2 matrix = {cosval, Complex(0, -sinval), Complex(0, -sinval), cosval}; + one_target_dense_matrix_gate(target_mask, control_mask, matrix, states); +} void ry_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -131,6 +251,15 @@ void ry_gate(std::uint64_t target_mask, Matrix2x2 matrix = {cosval, -sinval, sinval, cosval}; one_target_dense_matrix_gate(target_mask, control_mask, matrix, state); } +void ry_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double angle, + StateVectorBatched& states) { + const double cosval = std::cos(angle / 2.); + const double sinval = std::sin(angle / 2.); + Matrix2x2 matrix = {cosval, -sinval, sinval, cosval}; + one_target_dense_matrix_gate(target_mask, control_mask, matrix, states); +} void one_target_diagonal_matrix_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -145,6 +274,22 @@ void one_target_diagonal_matrix_gate(std::uint64_t target_mask, }); Kokkos::fence(); } +void one_target_diagonal_matrix_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + const DiagonalMatrix2x2& diag, + StateVectorBatched& states) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t basis = + insert_zero_at_mask_positions(it, target_mask | control_mask) | control_mask; + states._raw(batch_id, basis) *= diag[0]; + states._raw(batch_id, basis | target_mask) *= diag[1]; + }); + Kokkos::fence(); +} void rz_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -155,6 +300,15 @@ void rz_gate(std::uint64_t target_mask, DiagonalMatrix2x2 diag = {Complex(cosval, -sinval), Complex(cosval, sinval)}; one_target_diagonal_matrix_gate(target_mask, control_mask, diag, state); } +void rz_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double angle, + StateVectorBatched& states) { + const double cosval = std::cos(angle / 2.); + const double sinval = std::sin(angle / 2.); + DiagonalMatrix2x2 diag = {Complex(cosval, -sinval), Complex(cosval, sinval)}; + one_target_diagonal_matrix_gate(target_mask, control_mask, diag, states); +} Matrix2x2 get_IBMQ_matrix(double theta, double phi, double lambda) { Complex exp_val1 = Kokkos::exp(Complex(0, phi)); @@ -179,6 +333,23 @@ void u1_gate(std::uint64_t target_mask, }); Kokkos::fence(); } +void u1_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double lambda, + StateVectorBatched& states) { + Complex exp_val = Kokkos::exp(Complex(0, lambda)); + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t i = + internal::insert_zero_at_mask_positions(it, target_mask | control_mask) | + control_mask; + states._raw(batch_id, i | target_mask) *= exp_val; + }); + Kokkos::fence(); +} void u2_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -188,6 +359,14 @@ void u2_gate(std::uint64_t target_mask, one_target_dense_matrix_gate( target_mask, control_mask, get_IBMQ_matrix(Kokkos::numbers::pi / 2., phi, lambda), state); } +void u2_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double phi, + double lambda, + StateVectorBatched& states) { + one_target_dense_matrix_gate( + target_mask, control_mask, get_IBMQ_matrix(Kokkos::numbers::pi / 2., phi, lambda), states); +} void u3_gate(std::uint64_t target_mask, std::uint64_t control_mask, @@ -198,6 +377,15 @@ void u3_gate(std::uint64_t target_mask, one_target_dense_matrix_gate( target_mask, control_mask, get_IBMQ_matrix(theta, phi, lambda), state); } +void u3_gate(std::uint64_t target_mask, + std::uint64_t control_mask, + double theta, + double phi, + double lambda, + StateVectorBatched& states) { + one_target_dense_matrix_gate( + target_mask, control_mask, get_IBMQ_matrix(theta, phi, lambda), states); +} void swap_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector& state) { // '- target' is used for bit manipulation on unsigned type, not for its numerical meaning. @@ -212,6 +400,21 @@ void swap_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVecto }); Kokkos::fence(); } +void swap_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVectorBatched& states) { + std::uint64_t lower_target_mask = target_mask & -target_mask; + std::uint64_t upper_target_mask = target_mask ^ lower_target_mask; + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, + {states.batch_size(), states.dim() >> std::popcount(target_mask | control_mask)}), + KOKKOS_LAMBDA(std::uint64_t batch_id, std::uint64_t it) { + std::uint64_t basis = + insert_zero_at_mask_positions(it, target_mask | control_mask) | control_mask; + Kokkos::Experimental::swap(states._raw(batch_id, basis | lower_target_mask), + states._raw(batch_id, basis | upper_target_mask)); + }); + Kokkos::fence(); +} } // namespace internal } // namespace scaluq diff --git a/scaluq/operator/apply_pauli.cpp b/scaluq/operator/apply_pauli.cpp index 27549754..a0266d6c 100644 --- a/scaluq/operator/apply_pauli.cpp +++ b/scaluq/operator/apply_pauli.cpp @@ -43,6 +43,47 @@ void apply_pauli(std::uint64_t control_mask, }); Kokkos::fence(); } +void apply_pauli(std::uint64_t control_mask, + std::uint64_t bit_flip_mask, + std::uint64_t phase_flip_mask, + Complex coef, + StateVectorBatched& states) { + if (bit_flip_mask == 0) { + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, {states.batch_size(), states.dim() >> std::popcount(control_mask)}), + KOKKOS_LAMBDA(const std::uint64_t batch_id, const std::uint64_t i) { + std::uint64_t state_idx = + insert_zero_at_mask_positions(i, control_mask) | control_mask; + if (Kokkos::popcount(state_idx & phase_flip_mask) & 1) { + states._raw(batch_id, state_idx) *= -coef; + } else { + states._raw(batch_id, state_idx) *= coef; + } + }); + Kokkos::fence(); + return; + } + std::uint64_t pivot = sizeof(std::uint64_t) * 8 - std::countl_zero(bit_flip_mask) - 1; + std::uint64_t global_phase_90rot_count = std::popcount(bit_flip_mask & phase_flip_mask); + Complex global_phase = PHASE_M90ROT()[global_phase_90rot_count % 4]; + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, {states.batch_size(), states.dim() >> (std::popcount(control_mask) + 1)}), + KOKKOS_LAMBDA(const std::uint64_t batch_id, const std::uint64_t i) { + std::uint64_t basis_0 = + insert_zero_at_mask_positions(i, control_mask | 1ULL << pivot) | control_mask; + std::uint64_t basis_1 = basis_0 ^ bit_flip_mask; + Complex tmp1 = states._raw(batch_id, basis_0) * global_phase; + Complex tmp2 = states._raw(batch_id, basis_1) * global_phase; + if (Kokkos::popcount(basis_0 & phase_flip_mask) & 1) tmp2 = -tmp2; + if (Kokkos::popcount(basis_1 & phase_flip_mask) & 1) tmp1 = -tmp1; + states._raw(batch_id, basis_0) = tmp2 * coef; + states._raw(batch_id, basis_1) = tmp1 * coef; + }); + Kokkos::fence(); +} + void apply_pauli_rotation(std::uint64_t control_mask, std::uint64_t bit_flip_mask, std::uint64_t phase_flip_mask, @@ -98,4 +139,62 @@ void apply_pauli_rotation(std::uint64_t control_mask, Kokkos::fence(); } } +void apply_pauli_rotation(std::uint64_t control_mask, + std::uint64_t bit_flip_mask, + std::uint64_t phase_flip_mask, + Complex coef, + double angle, + StateVectorBatched& states) { + std::uint64_t global_phase_90_rot_count = std::popcount(bit_flip_mask & phase_flip_mask); + Complex true_angle = angle * coef; + const Complex cosval = Kokkos::cos(-true_angle / 2); + const Complex sinval = Kokkos::sin(-true_angle / 2); + if (bit_flip_mask == 0) { + const Complex cval_min = cosval - Complex(0, 1) * sinval; + const Complex cval_pls = cosval + Complex(0, 1) * sinval; + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, {states.batch_size(), states.dim() >> std::popcount(control_mask)}), + KOKKOS_LAMBDA(const std::uint64_t batch_id, const std::uint64_t i) { + std::uint64_t state_idx = + insert_zero_at_mask_positions(i, control_mask) | control_mask; + if (Kokkos::popcount(state_idx & phase_flip_mask) & 1) { + states._raw(batch_id, state_idx) *= cval_min; + } else { + states._raw(batch_id, state_idx) *= cval_pls; + } + }); + Kokkos::fence(); + return; + } else { + std::uint64_t pivot = sizeof(std::uint64_t) * 8 - std::countl_zero(bit_flip_mask) - 1; + Kokkos::parallel_for( + Kokkos::MDRangePolicy>( + {0, 0}, {states.batch_size(), states.dim() >> (std::popcount(control_mask) + 1)}), + KOKKOS_LAMBDA(const std::uint64_t batch_id, const std::uint64_t i) { + std::uint64_t basis_0 = + internal::insert_zero_at_mask_positions(i, control_mask | 1ULL << pivot) | + control_mask; + std::uint64_t basis_1 = basis_0 ^ bit_flip_mask; + + int bit_parity_0 = Kokkos::popcount(basis_0 & phase_flip_mask) & 1; + int bit_parity_1 = Kokkos::popcount(basis_1 & phase_flip_mask) & 1; + + // fetch values + Complex cval_0 = states._raw(batch_id, basis_0); + Complex cval_1 = states._raw(batch_id, basis_1); + + // set values + states._raw(batch_id, basis_0) = + cosval * cval_0 + + Complex(0, 1) * sinval * cval_1 * + PHASE_M90ROT()[(global_phase_90_rot_count + bit_parity_0 * 2) % 4]; + states._raw(batch_id, basis_1) = + cosval * cval_1 + + Complex(0, 1) * sinval * cval_0 * + PHASE_M90ROT()[(global_phase_90_rot_count + bit_parity_1 * 2) % 4]; + }); + Kokkos::fence(); + } +} } // namespace scaluq::internal diff --git a/scaluq/operator/apply_pauli.hpp b/scaluq/operator/apply_pauli.hpp index ceb41ab5..367f9304 100644 --- a/scaluq/operator/apply_pauli.hpp +++ b/scaluq/operator/apply_pauli.hpp @@ -1,6 +1,7 @@ #pragma once #include "../state/state_vector.hpp" +#include "../state/state_vector_batched.hpp" namespace scaluq::internal { void apply_pauli(std::uint64_t control_mask, @@ -8,10 +9,21 @@ void apply_pauli(std::uint64_t control_mask, std::uint64_t phase_flip_mask, Complex coef, StateVector& state_vector); +void apply_pauli(std::uint64_t control_mask, + std::uint64_t bit_flip_mask, + std::uint64_t phase_flip_mask, + Complex coef, + StateVectorBatched& states); void apply_pauli_rotation(std::uint64_t control_mask, std::uint64_t bit_flip_mask, std::uint64_t phase_flip_mask, Complex coef, double angle, StateVector& state_vector); +void apply_pauli_rotation(std::uint64_t control_mask, + std::uint64_t bit_flip_mask, + std::uint64_t phase_flip_mask, + Complex coef, + double angle, + StateVectorBatched& states); } // namespace scaluq::internal