diff --git a/tests/gate/batched_gate_test.cpp b/tests/gate/batched_gate_test.cpp index 9131b1ce..4392237a 100644 --- a/tests/gate/batched_gate_test.cpp +++ b/tests/gate/batched_gate_test.cpp @@ -315,7 +315,8 @@ void run_random_batched_gate_apply_pauli(std::uint64_t n_qubits) { for (std::uint64_t batch_id = 0; batch_id < states.batch_size(); batch_id++) { for (std::uint64_t i = 0; i < dim; i++) { - check_near((StdComplex)states_cp[batch_id][i], test_state[i]); + check_near((StdComplex)states_cp[batch_id][i], + (StdComplex)states_bef_cp[batch_id][i]); } } } @@ -379,7 +380,8 @@ void run_random_batched_gate_apply_pauli(std::uint64_t n_qubits) { for (std::uint64_t batch_id = 0; batch_id < states.batch_size(); batch_id++) { for (std::uint64_t i = 0; i < dim; i++) { - check_near((StdComplex)states_cp[batch_id][i], test_state[i]); + check_near((StdComplex)states_cp[batch_id][i], + (StdComplex)states_bef_cp[batch_id][i]); } } } @@ -713,10 +715,10 @@ TEST(BatchedGateTest, ApplyDenseMatrixGate) { run_random_batched_gate_apply_general_dense(6); run_random_batched_gate_apply_general_dense(6); } -// TEST(BatchedGateTest, ApplyPauliGate) { -// run_random_batched_gate_apply_pauli(5); -// run_random_batched_gate_apply_pauli(5); -// } +TEST(BatchedGateTest, ApplyPauliGate) { + run_random_batched_gate_apply_pauli(5); + run_random_batched_gate_apply_pauli(5); +} TEST(BatchedGateTest, ApplyProbablisticGate) { { @@ -778,24 +780,24 @@ void test_batched_gate(Gate gate_control, StateVectorBatched states = StateVectorBatched::Haar_random_state(BATCH_SIZE, n_qubits, true); auto amplitudes = states.get_amplitudes(); - StateVectorBatched state_controlled(BATCH_SIZE, n_qubits - std::popcount(control_mask)); + StateVectorBatched states_controlled(BATCH_SIZE, n_qubits - std::popcount(control_mask)); std::vector>> amplitudes_controlled( - BATCH_SIZE, std::vector>(state_controlled.dim())); + BATCH_SIZE, std::vector>(states_controlled.dim())); for (std::size_t i = 0; i < BATCH_SIZE; i++) { - for (std::uint64_t j = 0; j < state_controlled.dim(); j++) { + for (std::uint64_t j = 0; j < states_controlled.dim(); j++) { amplitudes_controlled[i][j] = amplitudes[i] [internal::insert_zero_at_mask_positions(j, control_mask) | control_mask]; } } - state_controlled.load(amplitudes_controlled); + states_controlled.load(amplitudes_controlled); gate_control->update_quantum_state(states); - gate_simple->update_quantum_state(state_controlled); + gate_simple->update_quantum_state(states_controlled); amplitudes = states.get_amplitudes(); - amplitudes_controlled = state_controlled.get_amplitudes(); - for (std::size_t i = 0; i < BATCH_SIZE; i++) { - for (std::uint64_t j = 0; j < state_controlled.dim(); j++) { - check_near((StdComplex)amplitudes[i][j], + amplitudes_controlled = states_controlled.get_amplitudes(); + for (std::uint64_t i = 0; i < BATCH_SIZE; i++) { + for (std::uint64_t j : std::views::iota(0ULL, states_controlled.dim())) { + check_near((StdComplex)amplitudes_controlled[i][j], (StdComplex) amplitudes[i][internal::insert_zero_at_mask_positions(j, control_mask) | control_mask]); @@ -809,12 +811,7 @@ template void test_batched_standard_gate_control(Factory factory, std::uint64_t n) { Random random; - std::vector shuffled(n); - std::iota(shuffled.begin(), shuffled.end(), 0ULL); - for (std::uint64_t i : std::views::iota(0ULL, n) | std::views::reverse) { - std::uint64_t j = random.int32() % (i + 1); - if (i != j) std::swap(shuffled[i], shuffled[j]); - } + std::vector shuffled = random.permutation(n); std::vector targets(num_target); for (std::uint64_t i : std::views::iota(0ULL, num_target)) { targets[i] = shuffled[i]; @@ -1018,8 +1015,8 @@ TEST(BatchGateTest, Control) { test_batched_standard_gate_control(gate::U2, n); test_batched_standard_gate_control(gate::U3, n); test_batched_standard_gate_control(gate::Swap, n); - // test_batched_pauli_control(n); - // test_batched_pauli_control(n); + test_batched_pauli_control(n); + test_batched_pauli_control(n); test_batched_matrix_control(n); test_batched_matrix_control(n); test_batched_matrix_control(n); @@ -1046,8 +1043,8 @@ TEST(BatchGateTest, Control) { test_batched_standard_gate_control(gate::U2, n); test_batched_standard_gate_control(gate::U3, n); test_batched_standard_gate_control(gate::Swap, n); - // test_batched_pauli_control(n); - // test_batched_pauli_control(n); + test_batched_pauli_control(n); + test_batched_pauli_control(n); test_batched_matrix_control(n); test_batched_matrix_control(n); test_batched_matrix_control(n);