Skip to content

Commit

Permalink
reform state_vector_batched_test
Browse files Browse the repository at this point in the history
  • Loading branch information
KowerKoint authored and KowerKoint committed Jan 29, 2025
1 parent 29e4e23 commit 1ab2d23
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 66 deletions.
2 changes: 1 addition & 1 deletion src/state/state_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ std::vector<std::uint64_t> StateVector<Prec>::sampling(std::uint64_t sampling_co
std::vector<std::uint64_t> next_todo;
for (std::size_t i = 0; i < todo_count; i++) {
if (result_buf_host[i] == _dim) {
next_todo.push_back(i);
next_todo.push_back(todo[i]);
} else {
result[todo[i]] = result_buf_host[i];
}
Expand Down
74 changes: 47 additions & 27 deletions src/state/state_vector_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,38 +98,58 @@ std::vector<std::vector<std::uint64_t>> StateVectorBatched<Prec>::sampling(
});
Kokkos::fence();

Kokkos::View<std::uint64_t**> result(
Kokkos::ViewAllocateWithoutInitializing("result"), _batch_size, sampling_count);
std::vector result(_batch_size, std::vector<std::uint64_t>(sampling_count));
Kokkos::Random_XorShift64_Pool<> rand_pool(seed);

Kokkos::parallel_for(
Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0, 0}, {_batch_size, sampling_count}),
KOKKOS_CLASS_LAMBDA(std::uint64_t batch_id, std::uint64_t i) {
auto rand_gen = rand_pool.get_state();
FloatType r = static_cast<FloatType>(rand_gen.drand(0., 1.));
std::uint64_t lo = 0, hi = stacked_prob.extent(1);
while (hi - lo > 1) {
std::uint64_t mid = (lo + hi) / 2;
if (stacked_prob(batch_id, mid) > r) {
hi = mid;
} else {
lo = mid;
std::vector<std::uint64_t> batch_todo(_batch_size * sampling_count);
std::vector<std::uint64_t> sample_todo(_batch_size * sampling_count);
for (std::uint64_t i = 0; i < _batch_size; i++) {
for (std::uint64_t j = 0; j < sampling_count; j++) {
std::uint64_t idx = i * sampling_count + j;
batch_todo[idx] = i;
sample_todo[idx] = j;
}
}
while (!batch_todo.empty()) {
std::size_t todo_count = batch_todo.size();
Kokkos::View<std::uint64_t*> batch_ids =
internal::convert_host_vector_to_device_view(batch_todo);
Kokkos::View<std::uint64_t*> result_buf(
Kokkos::ViewAllocateWithoutInitializing("result_buf"), todo_count);
Kokkos::parallel_for(
todo_count, KOKKOS_CLASS_LAMBDA(std::uint64_t idx) {
std::uint64_t batch_id = batch_ids[idx];
auto rand_gen = rand_pool.get_state();
FloatType r = static_cast<FloatType>(rand_gen.drand(0., 1.));
std::uint64_t lo = 0, hi = stacked_prob.extent(1);
while (hi - lo > 1) {
std::uint64_t mid = (lo + hi) / 2;
if (stacked_prob(batch_id, mid) > r) {
hi = mid;
} else {
lo = mid;
}
}
result_buf(idx) = lo;
rand_pool.free_state(rand_gen);
});
Kokkos::fence();
auto result_buf_host = internal::convert_device_view_to_host_vector(result_buf);
// Especially for F16 and BF16, sampling sometimes fails with result == _dim.
// In this case, re-sampling is performed.
std::vector<std::uint64_t> next_batch_todo;
std::vector<std::uint64_t> next_sample_todo;
for (std::size_t i = 0; i < todo_count; i++) {
if (result_buf_host[i] == _dim) {
next_batch_todo.push_back(batch_todo[i]);
next_sample_todo.push_back(sample_todo[i]);
} else {
result[batch_todo[i]][sample_todo[i]] = result_buf_host[i];
}
result(batch_id, i) = lo;
rand_pool.free_state(rand_gen);
});
Kokkos::fence();

auto view_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), result);
std::vector<std::vector<std::uint64_t>> vv(result.extent(0),
std::vector<std::uint64_t>(result.extent(1), 0));
for (size_t i = 0; i < result.extent(0); ++i) {
for (size_t j = 0; j < result.extent(1); ++j) {
vv[i][j] = view_h(i, j);
}
batch_todo.swap(next_batch_todo);
sample_todo.swap(next_sample_todo);
}
return vv;
return result;
}

template <Precision Prec>
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ add_executable(scaluq_test EXCLUDE_FROM_ALL
#operator/test_pauli_operator.cpp
#operator/test_operator.cpp
state/state_vector_test.cpp
#state/state_vector_batched_test.cpp
state/state_vector_batched_test.cpp
)

target_link_libraries(scaluq_test PUBLIC
Expand Down
59 changes: 34 additions & 25 deletions tests/state/state_vector_batched_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,31 @@
#include "../test_environment.hpp"
#include "../util/util.hpp"

using CComplex = std::complex<double>;

using namespace scaluq;

TEST(StateVectorBatchedTest, HaarRandomStateNorm) {
template <typename T>
class StateVectorBatchedTest : public FixtureBase<T> {};
TYPED_TEST_SUITE(StateVectorBatchedTest, TestTypes, NameGenerator);

TYPED_TEST(StateVectorBatchedTest, HaarRandomStateNorm) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 10, n_qubits = 3;
const auto states = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
const auto states = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);
auto norms = states.get_squared_norm();
for (auto x : norms) ASSERT_NEAR(x, 1., eps<double>);
for (auto x : norms) ASSERT_NEAR(x, 1., eps<Prec>);
}

TEST(StateVectorBatchedTest, LoadAndAmplitues) {
TYPED_TEST(StateVectorBatchedTest, LoadAndAmplitues) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 4, n_qubits = 3;
const std::uint64_t dim = 1 << n_qubits;
std::vector states_h(batch_size, std::vector<Complex<double>>(dim));
std::vector states_h(batch_size, std::vector<StdComplex>(dim));
for (std::uint64_t b = 0; b < batch_size; ++b) {
for (std::uint64_t i = 0; i < dim; ++i) {
states_h[b][i] = b * dim + i;
}
}
StateVectorBatched<double> states(batch_size, n_qubits);
StateVectorBatched<Prec> states(batch_size, n_qubits);

states.load(states_h);
auto amps = states.get_amplitudes();
Expand All @@ -36,11 +40,12 @@ TEST(StateVectorBatchedTest, LoadAndAmplitues) {
}
}

TEST(StateVectorBatchedTest, OperateState) {
TYPED_TEST(StateVectorBatchedTest, OperateState) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 4, n_qubits = 3;
auto states = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
auto states_add = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
const Complex<double> coef(2.1, 3.5);
auto states = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);
auto states_add = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);
const StdComplex coef(2.1, 3.5);

auto states_cp = states.copy();
for (std::uint64_t b = 0; b < batch_size; ++b) {
Expand Down Expand Up @@ -71,22 +76,24 @@ TEST(StateVectorBatchedTest, OperateState) {
}
}

TEST(StateVectorBatchedTest, ZeroProbs) {
TYPED_TEST(StateVectorBatchedTest, ZeroProbs) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 4, n_qubits = 3;
auto states = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
auto states = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);

for (std::uint64_t i = 0; i < n_qubits; ++i) {
auto zero_probs = states.get_zero_probability(i);
for (std::uint64_t b = 0; b < batch_size; ++b) {
auto state = states.get_state_vector_at(b);
ASSERT_NEAR(zero_probs[b], state.get_zero_probability(i), eps<double>);
ASSERT_NEAR(zero_probs[b], state.get_zero_probability(i), eps<Prec>);
}
}
}

TEST(StateVectorBatchedTest, MarginalProbs) {
TYPED_TEST(StateVectorBatchedTest, MarginalProbs) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 4, n_qubits = 5;
auto states = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
auto states = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);

Random rd(0);
for (std::uint64_t i = 0; i < 10; ++i) {
Expand All @@ -97,27 +104,29 @@ TEST(StateVectorBatchedTest, MarginalProbs) {
auto mg_probs = states.get_marginal_probability(targets);
for (std::uint64_t b = 0; b < batch_size; ++b) {
auto state = states.get_state_vector_at(b);
ASSERT_NEAR(mg_probs[b], state.get_marginal_probability(targets), eps<double>);
ASSERT_NEAR(mg_probs[b], state.get_marginal_probability(targets), eps<Prec>);
}
}
}

TEST(StateVectorBatchedTest, Entropy) {
TYPED_TEST(StateVectorBatchedTest, Entropy) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 4, n_qubits = 3;
auto states = StateVectorBatched<double>::Haar_random_state(batch_size, n_qubits, false);
auto states = StateVectorBatched<Prec>::Haar_random_state(batch_size, n_qubits, false);

auto entropies = states.get_entropy();
for (std::uint64_t b = 0; b < batch_size; ++b) {
auto state = states.get_state_vector_at(b);
ASSERT_NEAR(entropies[b], state.get_entropy(), eps<double>);
ASSERT_NEAR(entropies[b], state.get_entropy(), eps<Prec>);
}
}

TEST(StateVectorBatchedTest, Sampling) {
TYPED_TEST(StateVectorBatchedTest, Sampling) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t batch_size = 2, n_qubits = 3;
StateVectorBatched<double> states(batch_size, n_qubits);
states.load(std::vector<std::vector<Complex<double>>>{{1, 4, 5, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 6, 4, 1}});
StateVectorBatched<Prec> states(batch_size, n_qubits);
states.load(
std::vector<std::vector<StdComplex>>{{1, 4, 5, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 6, 4, 1}});
states.normalize();
auto result = states.sampling(4096);
std::vector cnt(2, std::vector<std::uint64_t>(states.dim(), 0));
Expand Down
22 changes: 10 additions & 12 deletions tests/state/state_vector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#include "../test_environment.hpp"
#include "../util/util.hpp"

using CComplex = std::complex<double>;

using namespace scaluq;

template <typename T>
Expand Down Expand Up @@ -51,7 +49,7 @@ TYPED_TEST(StateVectorTest, ZeroNormState) {
auto state_cp = state.get_amplitudes();

for (std::uint64_t i = 0; i < state.dim(); ++i) {
ASSERT_EQ((CComplex)state_cp[i], CComplex(0, 0));
ASSERT_EQ((StdComplex)state_cp[i], StdComplex(0, 0));
}
}

Expand All @@ -65,9 +63,9 @@ TYPED_TEST(StateVectorTest, ComputationalBasisState) {

for (std::uint64_t i = 0; i < state.dim(); ++i) {
if (i == 31) {
ASSERT_EQ((CComplex)state_cp[i], CComplex(1, 0));
ASSERT_EQ((StdComplex)state_cp[i], StdComplex(1, 0));
} else {
ASSERT_EQ((CComplex)state_cp[i], CComplex(0, 0));
ASSERT_EQ((StdComplex)state_cp[i], StdComplex(0, 0));
}
}
}
Expand Down Expand Up @@ -103,15 +101,15 @@ TYPED_TEST(StateVectorTest, AddState) {
auto new_vec = state1.get_amplitudes();

for (std::uint64_t i = 0; i < state1.dim(); ++i) {
CComplex res = new_vec[i], val = (CComplex)vec1[i] + (CComplex)vec2[i];
StdComplex res = new_vec[i], val = vec1[i] + vec2[i];
ASSERT_NEAR(res.real(), val.real(), eps<Prec>);
ASSERT_NEAR(res.imag(), val.imag(), eps<Prec>);
}
}

TYPED_TEST(StateVectorTest, AddStateWithCoef) {
constexpr Precision Prec = TestFixture::Prec;
const CComplex coef(2.5, 1.3);
const StdComplex coef(2.5, 1.3);
const std::uint64_t n = 10;
StateVector state1(StateVector<Prec>::Haar_random_state(n));
StateVector state2(StateVector<Prec>::Haar_random_state(n));
Expand All @@ -122,7 +120,7 @@ TYPED_TEST(StateVectorTest, AddStateWithCoef) {
auto new_vec = state1.get_amplitudes();

for (std::uint64_t i = 0; i < state1.dim(); ++i) {
CComplex res = new_vec[i], val = (CComplex)vec1[i] + coef * (CComplex)vec2[i];
StdComplex res = new_vec[i], val = vec1[i] + coef * vec2[i];
ASSERT_NEAR(res.real(), val.real(), eps<Prec>);
ASSERT_NEAR(res.imag(), val.imag(), eps<Prec>);
}
Expand All @@ -131,15 +129,15 @@ TYPED_TEST(StateVectorTest, AddStateWithCoef) {
TYPED_TEST(StateVectorTest, MultiplyCoef) {
constexpr Precision Prec = TestFixture::Prec;
const std::uint64_t n = 10;
const CComplex coef(0.5, 0.2);
const StdComplex coef(0.5, 0.2);

StateVector state(StateVector<Prec>::Haar_random_state(n));
auto vec = state.get_amplitudes();
state.multiply_coef(coef);
auto new_vec = state.get_amplitudes();

for (std::uint64_t i = 0; i < state.dim(); ++i) {
CComplex res = new_vec[i], val = coef * (CComplex)vec[i];
StdComplex res = new_vec[i], val = coef * vec[i];
ASSERT_NEAR(res.real(), val.real(), eps<Prec>);
ASSERT_NEAR(res.imag(), val.imag(), eps<Prec>);
}
Expand Down Expand Up @@ -174,12 +172,12 @@ TYPED_TEST(StateVectorTest, EntropyCalculation) {
auto state_cp = state.get_amplitudes();
ASSERT_NEAR(state.get_squared_norm(), 1, eps<Prec>);
Eigen::VectorXcd test_state(dim);
for (std::uint64_t i = 0; i < dim; ++i) test_state[i] = (CComplex)state_cp[i];
for (std::uint64_t i = 0; i < dim; ++i) test_state[i] = state_cp[i];

for (std::uint64_t target = 0; target < n; ++target) {
double ent = 0;
for (std::uint64_t ind = 0; ind < dim; ++ind) {
CComplex z = test_state[ind];
StdComplex z = test_state[ind];
double prob = z.real() * z.real() + z.imag() * z.imag();
if (prob > 0.) ent += -prob * std::log2(prob);
}
Expand Down

0 comments on commit 1ab2d23

Please sign in to comment.