Skip to content

Commit

Permalink
add batch update
Browse files Browse the repository at this point in the history
  • Loading branch information
Glacialte committed Sep 18, 2024
1 parent 8cc26cc commit 6d73ccc
Show file tree
Hide file tree
Showing 12 changed files with 530 additions and 47 deletions.
13 changes: 12 additions & 1 deletion scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,18 @@ class GateBase : public std::enable_shared_from_this<GateBase> {
}
}

std::string get_qubit_info_as_string(const std::string& indent) const {
[[nodiscard]] std::vector<std::uint64_t> mask_to_vector(std::uint64_t mask) const {
std::vector<std::uint64_t> qubits;
for (std::uint64_t i = 0; i < 64; ++i) {
if ((mask >> i) & 1) qubits.push_back(i);
}
return qubits;
}

[[nodiscard]]

std::string
get_qubit_info_as_string(const std::string& indent) const {
std::ostringstream ss;
auto targets = target_qubit_list();
auto controls = control_qubit_list();
Expand Down
10 changes: 8 additions & 2 deletions scaluq/gate/gate_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ class OneTargetMatrixGateImpl : public GateBase {
one_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, state_vector);
}

void update_quantum_state(StateVectorBatched& states) const override {}
void update_quantum_state(StateVectorBatched& states) const override {
check_qubit_mask_within_bounds(states.get_state_vector_at(0));
one_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, states);
}

std::string to_string(const std::string& indent) const override {
std::ostringstream ss;
Expand Down Expand Up @@ -107,7 +110,10 @@ class TwoTargetMatrixGateImpl : public GateBase {
two_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, state_vector);
}

void update_quantum_state(StateVectorBatched& states) const override {}
void update_quantum_state(StateVectorBatched& states) const override {
check_qubit_mask_within_bounds(states.get_state_vector_at(0));
two_target_dense_matrix_gate(_target_mask, _control_mask, _matrix, states);
}

std::string to_string(const std::string& indent) const override {
std::ostringstream ss;
Expand Down
11 changes: 9 additions & 2 deletions scaluq/gate/gate_pauli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ class PauliGateImpl : public GateBase {
apply_pauli(_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), state_vector);
}

void update_quantum_state(StateVectorBatched& states) const override {}
void update_quantum_state(StateVectorBatched& states) const override {
auto [bit_flip_mask, phase_flip_mask] = _pauli.get_XZ_mask_representation();
apply_pauli(_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), states);
}

std::string to_string(const std::string& indent) const override {
std::ostringstream ss;
Expand Down Expand Up @@ -75,7 +78,11 @@ class PauliRotationGateImpl : public GateBase {
_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), _angle, state_vector);
}

void update_quantum_state(StateVectorBatched& states) const override {}
void update_quantum_state(StateVectorBatched& states) const override {
auto [bit_flip_mask, phase_flip_mask] = _pauli.get_XZ_mask_representation();
apply_pauli_rotation(
_control_mask, bit_flip_mask, phase_flip_mask, _pauli.coef(), _angle, states);
}

std::string to_string(const std::string& indent) const override {
std::ostringstream ss;
Expand Down
33 changes: 26 additions & 7 deletions scaluq/gate/gate_probablistic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace scaluq {
namespace internal {
class ProbablisticGateImpl : public GateBase {
std::vector<double> _distribution;
std::vector<double> _cumlative_distribution;
std::vector<double> _cumulative_distribution;
std::vector<Gate> _gate_list;

public:
Expand All @@ -21,10 +21,10 @@ class ProbablisticGateImpl : public GateBase {
if (n != gate_list.size()) {
throw std::runtime_error("distribution and gate_list have different size.");
}
_cumlative_distribution.resize(n + 1);
_cumulative_distribution.resize(n + 1);
std::partial_sum(
distribution.begin(), distribution.end(), _cumlative_distribution.begin() + 1);
if (std::abs(_cumlative_distribution.back() - 1.) > 1e-6) {
distribution.begin(), distribution.end(), _cumulative_distribution.begin() + 1);
if (std::abs(_cumulative_distribution.back() - 1.) > 1e-6) {
throw std::runtime_error("Sum of distribution must be equal to 1.");
}
}
Expand Down Expand Up @@ -79,14 +79,33 @@ class ProbablisticGateImpl : public GateBase {
void update_quantum_state(StateVector& state_vector) const override {
Random random;
double r = random.uniform();
std::uint64_t i = std::distance(_cumlative_distribution.begin(),
std::ranges::upper_bound(_cumlative_distribution, r)) -
std::uint64_t i = std::distance(_cumulative_distribution.begin(),
std::ranges::upper_bound(_cumulative_distribution, r)) -
1;
if (i >= _gate_list.size()) i = _gate_list.size() - 1;
_gate_list[i]->update_quantum_state(state_vector);
}

void update_quantum_state(StateVectorBatched& states) const override {}
// これでいいのか?わからない...
void update_quantum_state(StateVectorBatched& states) const override {
Random random;
std::vector<double> r(states.batch_size());
std::ranges::generate(r, [&random]() { return random.uniform(); });
std::vector<std::uint64_t> indicies(states.batch_size());
std::ranges::transform(r, indicies.begin(), [this](double r) {
return std::distance(_cumulative_distribution.begin(),
std::ranges::upper_bound(_cumulative_distribution, r)) -
1;
});
std::ranges::transform(indicies, indicies.begin(), [this](std::uint64_t i) {
if (i >= _gate_list.size()) i = _gate_list.size() - 1;
return i;
});
for (std::size_t i = 0; i < states.batch_size(); ++i) {
auto state_vector = states.get_state_vector_at(i);
_gate_list[indicies[i]]->update_quantum_state(state_vector);
}
}

std::string to_string(const std::string& indent) const override {
std::ostringstream ss;
Expand Down
Loading

0 comments on commit 6d73ccc

Please sign in to comment.