diff --git a/src/state/state_vector.cpp b/src/state/state_vector.cpp index 7fa2c76..4fa5874 100644 --- a/src/state/state_vector.cpp +++ b/src/state/state_vector.cpp @@ -229,7 +229,7 @@ std::vector StateVector::sampling(std::uint64_t sampling_co std::vector next_todo; for (std::size_t i = 0; i < todo_count; i++) { if (result_buf_host[i] == _dim) { - next_todo.push_back(i); + next_todo.push_back(todo[i]); } else { result[todo[i]] = result_buf_host[i]; } diff --git a/src/state/state_vector_batched.cpp b/src/state/state_vector_batched.cpp index 167e5ff..60e609d 100644 --- a/src/state/state_vector_batched.cpp +++ b/src/state/state_vector_batched.cpp @@ -98,38 +98,58 @@ std::vector> StateVectorBatched::sampling( }); Kokkos::fence(); - Kokkos::View result( - Kokkos::ViewAllocateWithoutInitializing("result"), _batch_size, sampling_count); + std::vector result(_batch_size, std::vector(sampling_count)); Kokkos::Random_XorShift64_Pool<> rand_pool(seed); - - Kokkos::parallel_for( - Kokkos::MDRangePolicy>({0, 0}, {_batch_size, sampling_count}), - KOKKOS_CLASS_LAMBDA(std::uint64_t batch_id, std::uint64_t i) { - auto rand_gen = rand_pool.get_state(); - FloatType r = static_cast(rand_gen.drand(0., 1.)); - std::uint64_t lo = 0, hi = stacked_prob.extent(1); - while (hi - lo > 1) { - std::uint64_t mid = (lo + hi) / 2; - if (stacked_prob(batch_id, mid) > r) { - hi = mid; - } else { - lo = mid; + std::vector batch_todo(_batch_size * sampling_count); + std::vector sample_todo(_batch_size * sampling_count); + for (std::uint64_t i = 0; i < _batch_size; i++) { + for (std::uint64_t j = 0; j < sampling_count; j++) { + std::uint64_t idx = i * sampling_count + j; + batch_todo[idx] = i; + sample_todo[idx] = j; + } + } + while (!batch_todo.empty()) { + std::size_t todo_count = batch_todo.size(); + Kokkos::View batch_ids = + internal::convert_host_vector_to_device_view(batch_todo); + Kokkos::View result_buf( + Kokkos::ViewAllocateWithoutInitializing("result_buf"), todo_count); + Kokkos::parallel_for( + todo_count, KOKKOS_CLASS_LAMBDA(std::uint64_t idx) { + std::uint64_t batch_id = batch_ids[idx]; + auto rand_gen = rand_pool.get_state(); + FloatType r = static_cast(rand_gen.drand(0., 1.)); + std::uint64_t lo = 0, hi = stacked_prob.extent(1); + while (hi - lo > 1) { + std::uint64_t mid = (lo + hi) / 2; + if (stacked_prob(batch_id, mid) > r) { + hi = mid; + } else { + lo = mid; + } } + result_buf(idx) = lo; + rand_pool.free_state(rand_gen); + }); + Kokkos::fence(); + auto result_buf_host = internal::convert_device_view_to_host_vector(result_buf); + // Especially for F16 and BF16, sampling sometimes fails with result == _dim. + // In this case, re-sampling is performed. + std::vector next_batch_todo; + std::vector next_sample_todo; + for (std::size_t i = 0; i < todo_count; i++) { + if (result_buf_host[i] == _dim) { + next_batch_todo.push_back(batch_todo[i]); + next_sample_todo.push_back(sample_todo[i]); + } else { + result[batch_todo[i]][sample_todo[i]] = result_buf_host[i]; } - result(batch_id, i) = lo; - rand_pool.free_state(rand_gen); - }); - Kokkos::fence(); - - auto view_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), result); - std::vector> vv(result.extent(0), - std::vector(result.extent(1), 0)); - for (size_t i = 0; i < result.extent(0); ++i) { - for (size_t j = 0; j < result.extent(1); ++j) { - vv[i][j] = view_h(i, j); } + batch_todo.swap(next_batch_todo); + sample_todo.swap(next_sample_todo); } - return vv; + return result; } template diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e7ff0cd..6279459 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,7 +13,7 @@ add_executable(scaluq_test EXCLUDE_FROM_ALL #operator/test_pauli_operator.cpp #operator/test_operator.cpp state/state_vector_test.cpp - #state/state_vector_batched_test.cpp + state/state_vector_batched_test.cpp ) target_link_libraries(scaluq_test PUBLIC diff --git a/tests/state/state_vector_batched_test.cpp b/tests/state/state_vector_batched_test.cpp index ed09adf..8063907 100644 --- a/tests/state/state_vector_batched_test.cpp +++ b/tests/state/state_vector_batched_test.cpp @@ -5,27 +5,31 @@ #include "../test_environment.hpp" #include "../util/util.hpp" -using CComplex = std::complex; - using namespace scaluq; -TEST(StateVectorBatchedTest, HaarRandomStateNorm) { +template +class StateVectorBatchedTest : public FixtureBase {}; +TYPED_TEST_SUITE(StateVectorBatchedTest, TestTypes, NameGenerator); + +TYPED_TEST(StateVectorBatchedTest, HaarRandomStateNorm) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 10, n_qubits = 3; - const auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + const auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); auto norms = states.get_squared_norm(); - for (auto x : norms) ASSERT_NEAR(x, 1., eps); + for (auto x : norms) ASSERT_NEAR(x, 1., eps); } -TEST(StateVectorBatchedTest, LoadAndAmplitues) { +TYPED_TEST(StateVectorBatchedTest, LoadAndAmplitues) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 4, n_qubits = 3; const std::uint64_t dim = 1 << n_qubits; - std::vector states_h(batch_size, std::vector>(dim)); + std::vector states_h(batch_size, std::vector(dim)); for (std::uint64_t b = 0; b < batch_size; ++b) { for (std::uint64_t i = 0; i < dim; ++i) { states_h[b][i] = b * dim + i; } } - StateVectorBatched states(batch_size, n_qubits); + StateVectorBatched states(batch_size, n_qubits); states.load(states_h); auto amps = states.get_amplitudes(); @@ -36,11 +40,12 @@ TEST(StateVectorBatchedTest, LoadAndAmplitues) { } } -TEST(StateVectorBatchedTest, OperateState) { +TYPED_TEST(StateVectorBatchedTest, OperateState) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 4, n_qubits = 3; - auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); - auto states_add = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); - const Complex coef(2.1, 3.5); + auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + auto states_add = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + const StdComplex coef(2.1, 3.5); auto states_cp = states.copy(); for (std::uint64_t b = 0; b < batch_size; ++b) { @@ -71,22 +76,24 @@ TEST(StateVectorBatchedTest, OperateState) { } } -TEST(StateVectorBatchedTest, ZeroProbs) { +TYPED_TEST(StateVectorBatchedTest, ZeroProbs) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 4, n_qubits = 3; - auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); for (std::uint64_t i = 0; i < n_qubits; ++i) { auto zero_probs = states.get_zero_probability(i); for (std::uint64_t b = 0; b < batch_size; ++b) { auto state = states.get_state_vector_at(b); - ASSERT_NEAR(zero_probs[b], state.get_zero_probability(i), eps); + ASSERT_NEAR(zero_probs[b], state.get_zero_probability(i), eps); } } } -TEST(StateVectorBatchedTest, MarginalProbs) { +TYPED_TEST(StateVectorBatchedTest, MarginalProbs) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 4, n_qubits = 5; - auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); Random rd(0); for (std::uint64_t i = 0; i < 10; ++i) { @@ -97,27 +104,29 @@ TEST(StateVectorBatchedTest, MarginalProbs) { auto mg_probs = states.get_marginal_probability(targets); for (std::uint64_t b = 0; b < batch_size; ++b) { auto state = states.get_state_vector_at(b); - ASSERT_NEAR(mg_probs[b], state.get_marginal_probability(targets), eps); + ASSERT_NEAR(mg_probs[b], state.get_marginal_probability(targets), eps); } } } -TEST(StateVectorBatchedTest, Entropy) { +TYPED_TEST(StateVectorBatchedTest, Entropy) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 4, n_qubits = 3; - auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); + auto states = StateVectorBatched::Haar_random_state(batch_size, n_qubits, false); auto entropies = states.get_entropy(); for (std::uint64_t b = 0; b < batch_size; ++b) { auto state = states.get_state_vector_at(b); - ASSERT_NEAR(entropies[b], state.get_entropy(), eps); + ASSERT_NEAR(entropies[b], state.get_entropy(), eps); } } -TEST(StateVectorBatchedTest, Sampling) { +TYPED_TEST(StateVectorBatchedTest, Sampling) { + constexpr Precision Prec = TestFixture::Prec; const std::uint64_t batch_size = 2, n_qubits = 3; - StateVectorBatched states(batch_size, n_qubits); - states.load(std::vector>>{{1, 4, 5, 0, 0, 0, 0, 0}, - {0, 0, 0, 0, 0, 6, 4, 1}}); + StateVectorBatched states(batch_size, n_qubits); + states.load( + std::vector>{{1, 4, 5, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 6, 4, 1}}); states.normalize(); auto result = states.sampling(4096); std::vector cnt(2, std::vector(states.dim(), 0)); diff --git a/tests/state/state_vector_test.cpp b/tests/state/state_vector_test.cpp index 8c6b74d..8792576 100644 --- a/tests/state/state_vector_test.cpp +++ b/tests/state/state_vector_test.cpp @@ -5,8 +5,6 @@ #include "../test_environment.hpp" #include "../util/util.hpp" -using CComplex = std::complex; - using namespace scaluq; template @@ -51,7 +49,7 @@ TYPED_TEST(StateVectorTest, ZeroNormState) { auto state_cp = state.get_amplitudes(); for (std::uint64_t i = 0; i < state.dim(); ++i) { - ASSERT_EQ((CComplex)state_cp[i], CComplex(0, 0)); + ASSERT_EQ((StdComplex)state_cp[i], StdComplex(0, 0)); } } @@ -65,9 +63,9 @@ TYPED_TEST(StateVectorTest, ComputationalBasisState) { for (std::uint64_t i = 0; i < state.dim(); ++i) { if (i == 31) { - ASSERT_EQ((CComplex)state_cp[i], CComplex(1, 0)); + ASSERT_EQ((StdComplex)state_cp[i], StdComplex(1, 0)); } else { - ASSERT_EQ((CComplex)state_cp[i], CComplex(0, 0)); + ASSERT_EQ((StdComplex)state_cp[i], StdComplex(0, 0)); } } } @@ -103,7 +101,7 @@ TYPED_TEST(StateVectorTest, AddState) { auto new_vec = state1.get_amplitudes(); for (std::uint64_t i = 0; i < state1.dim(); ++i) { - CComplex res = new_vec[i], val = (CComplex)vec1[i] + (CComplex)vec2[i]; + StdComplex res = new_vec[i], val = vec1[i] + vec2[i]; ASSERT_NEAR(res.real(), val.real(), eps); ASSERT_NEAR(res.imag(), val.imag(), eps); } @@ -111,7 +109,7 @@ TYPED_TEST(StateVectorTest, AddState) { TYPED_TEST(StateVectorTest, AddStateWithCoef) { constexpr Precision Prec = TestFixture::Prec; - const CComplex coef(2.5, 1.3); + const StdComplex coef(2.5, 1.3); const std::uint64_t n = 10; StateVector state1(StateVector::Haar_random_state(n)); StateVector state2(StateVector::Haar_random_state(n)); @@ -122,7 +120,7 @@ TYPED_TEST(StateVectorTest, AddStateWithCoef) { auto new_vec = state1.get_amplitudes(); for (std::uint64_t i = 0; i < state1.dim(); ++i) { - CComplex res = new_vec[i], val = (CComplex)vec1[i] + coef * (CComplex)vec2[i]; + StdComplex res = new_vec[i], val = vec1[i] + coef * vec2[i]; ASSERT_NEAR(res.real(), val.real(), eps); ASSERT_NEAR(res.imag(), val.imag(), eps); } @@ -131,7 +129,7 @@ TYPED_TEST(StateVectorTest, AddStateWithCoef) { TYPED_TEST(StateVectorTest, MultiplyCoef) { constexpr Precision Prec = TestFixture::Prec; const std::uint64_t n = 10; - const CComplex coef(0.5, 0.2); + const StdComplex coef(0.5, 0.2); StateVector state(StateVector::Haar_random_state(n)); auto vec = state.get_amplitudes(); @@ -139,7 +137,7 @@ TYPED_TEST(StateVectorTest, MultiplyCoef) { auto new_vec = state.get_amplitudes(); for (std::uint64_t i = 0; i < state.dim(); ++i) { - CComplex res = new_vec[i], val = coef * (CComplex)vec[i]; + StdComplex res = new_vec[i], val = coef * vec[i]; ASSERT_NEAR(res.real(), val.real(), eps); ASSERT_NEAR(res.imag(), val.imag(), eps); } @@ -174,12 +172,12 @@ TYPED_TEST(StateVectorTest, EntropyCalculation) { auto state_cp = state.get_amplitudes(); ASSERT_NEAR(state.get_squared_norm(), 1, eps); Eigen::VectorXcd test_state(dim); - for (std::uint64_t i = 0; i < dim; ++i) test_state[i] = (CComplex)state_cp[i]; + for (std::uint64_t i = 0; i < dim; ++i) test_state[i] = state_cp[i]; for (std::uint64_t target = 0; target < n; ++target) { double ent = 0; for (std::uint64_t ind = 0; ind < dim; ++ind) { - CComplex z = test_state[ind]; + StdComplex z = test_state[ind]; double prob = z.real() * z.real() + z.imag() * z.imag(); if (prob > 0.) ent += -prob * std::log2(prob); }