Skip to content

Commit

Permalink
only float template
Browse files Browse the repository at this point in the history
  • Loading branch information
gandalfr-KY committed Oct 4, 2024
1 parent 05a47a9 commit 736516d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 58 deletions.
65 changes: 24 additions & 41 deletions exe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,23 @@ namespace internal {

enum class GateType { Unknown, X };

template <std::floating_point FloatType, typename Space>
class XGateImpl;

template <typename T>
inline constexpr bool lazy_false_v = false;

template <typename T>
constexpr GateType get_gate_type() {
if constexpr (std::is_same_v<T, XGateImpl<float, DefaultSpace>> ||
std::is_same_v<T, XGateImpl<double, DefaultSpace>> ||
std::is_same_v<T, XGateImpl<float, HostSpace>> ||
std::is_same_v<T, XGateImpl<double, HostSpace>>) {
if constexpr (std::is_same_v<T, XGateImpl>) {
return GateType::X;
} else {
static_assert(lazy_false_v<T>, "unknown GateImpl");
}
}

// GateBase テンプレートクラス
template <std::floating_point _FloatType, typename _Space>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType, _Space>> {
public:
using FloatType = _FloatType;
using Space = _Space;

class GateBase : public std::enable_shared_from_this<GateBase> {
protected:
std::uint64_t _target_mask, _control_mask;

void check_qubit_mask_within_bounds(const StateVector<FloatType, Space>& state_vector) const {
template <std::floating_point FloatType>
void check_qubit_mask_within_bounds(const StateVector<FloatType>& state_vector) const {
std::uint64_t full_mask = (1ULL << state_vector.n_qubits()) - 1;
if ((_target_mask | _control_mask) > full_mask) [[unlikely]] {
throw std::runtime_error(
Expand Down Expand Up @@ -92,22 +80,17 @@ class GateBase : public std::enable_shared_from_this<GateBase<_FloatType, _Space
return _target_mask | _control_mask;
}

virtual void update_quantum_state(StateVector<FloatType, Space>& state_vector) const = 0;
virtual void update_quantum_state(StateVector<double>& state_vector) const = 0;
virtual void update_quantum_state(StateVector<float>& state_vector) const = 0;

[[nodiscard]] virtual std::string to_string(const std::string& indent = "") const = 0;
};

template <typename T>
concept GateImpl = std::derived_from<T, GateBase<typename T::FloatType, typename T::Space>>;
concept GateImpl = std::derived_from<T, GateBase>;

template <GateImpl T>
class GatePtr {
using FloatType = T::FloatType;
using Space = T::Space;

static_assert(std::derived_from<T, GateBase<FloatType, Space>>,
"T must derive from GateBase<FloatType, Space>");

private:
std::shared_ptr<const T> _gate_ptr;
GateType _gate_type;
Expand Down Expand Up @@ -166,13 +149,10 @@ class GatePtr {
}
};

template <std::floating_point FloatType, typename Space>
using Gate = GatePtr<GateBase<FloatType, Space>>;
using Gate = GatePtr<GateBase>;

template <std::floating_point FloatType, typename Space>
void x_gate(std::uint64_t target_mask,
std::uint64_t control_mask,
StateVector<FloatType, Space>& state) {
template <std::floating_point FloatType>
void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector<FloatType>& state) {
Kokkos::parallel_for(
state.dim() >> std::popcount(target_mask | control_mask), KOKKOS_LAMBDA(std::uint64_t it) {
std::uint64_t i =
Expand All @@ -182,12 +162,16 @@ void x_gate(std::uint64_t target_mask,
Kokkos::fence();
}

template <std::floating_point FloatType, typename Space>
class XGateImpl : public GateBase<FloatType, Space> {
class XGateImpl : public GateBase {
public:
using GateBase<FloatType, Space>::GateBase;
using GateBase::GateBase;

void update_quantum_state(StateVector<double>& state_vector) const override {
this->check_qubit_mask_within_bounds(state_vector);
x_gate(this->_target_mask, this->_control_mask, state_vector);
}

void update_quantum_state(StateVector<FloatType, Space>& state_vector) const override {
void update_quantum_state(StateVector<float>& state_vector) const override {
this->check_qubit_mask_within_bounds(state_vector);
x_gate(this->_target_mask, this->_control_mask, state_vector);
}
Expand All @@ -202,7 +186,7 @@ class XGateImpl : public GateBase<FloatType, Space> {
class GateFactory {
public:
template <GateImpl T, typename... Args>
static internal::Gate<typename T::FloatType, typename T::Space> create_gate(Args... args) {
static internal::Gate create_gate(Args... args) {
return {std::make_shared<const T>(args...)};
}
};
Expand All @@ -211,10 +195,9 @@ class GateFactory {

namespace gate {

template <std::floating_point FloatType, typename Space>
inline internal::Gate<FloatType, Space> X(std::uint64_t target,
const std::vector<std::uint64_t>& control_qubits = {}) {
return internal::GateFactory::create_gate<internal::XGateImpl<FloatType, Space>>(
inline internal::Gate X(std::uint64_t target,
const std::vector<std::uint64_t>& control_qubits = {}) {
return internal::GateFactory::create_gate<internal::XGateImpl>(
internal::vector_to_mask({target}), internal::vector_to_mask(control_qubits));
}

Expand All @@ -226,9 +209,9 @@ int main() {
Kokkos::initialize();
{
std::uint64_t n_qubits = 3;
scaluq::StateVector<double, scaluq::HostSpace> state(n_qubits);
scaluq::StateVector<double> state(n_qubits);
state.load({0, 1, 2, 3, 4, 5, 6, 7});
auto x_gate = scaluq::gate::X<double, scaluq::HostSpace>(1, {0, 2});
auto x_gate = scaluq::gate::X(1, {0, 2});
x_gate->update_quantum_state(state);

std::cout << state << std::endl;
Expand Down
16 changes: 9 additions & 7 deletions scaluq/state/state_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@ namespace scaluq {
using HostSpace = Kokkos::HostSpace;
using DefaultSpace = Kokkos::DefaultExecutionSpace;

#define STATE_VECTOR_TEMPLATE(FloatType, Space) \
template <std::floating_point FloatType = double, typename Space = DefaultSpace>
// #define STATE_VECTOR_TEMPLATE(FloatType, Space) \
// template <std::floating_point FloatType = double, typename Space = DefaultSpace>

STATE_VECTOR_TEMPLATE(FloatType, Space)
#define STATE_VECTOR_TEMPLATE(FloatType) template <std::floating_point FloatType = double>

template <std::floating_point FloatType = double>
class StateVector {
std::uint64_t _n_qubits;
std::uint64_t _dim;
using ComplexType = Kokkos::complex<FloatType>;

static_assert(std::is_same_v<Space, HostSpace> || std::is_same_v<Space, DefaultSpace>,
"Unsupported execution space tag");
// static_assert(std::is_same_v<Space, HostSpace> || std::is_same_v<Space, DefaultSpace>,
// "Unsupported execution space tag");

public:
static constexpr std::uint64_t UNMEASURED = 2;
Kokkos::View<ComplexType*, Space> _raw;
Kokkos::View<ComplexType*> _raw;
StateVector() = default;
StateVector(std::uint64_t n_qubits)
: _n_qubits(n_qubits),
Expand Down Expand Up @@ -62,7 +64,7 @@ class StateVector {
[[nodiscard]] static StateVector Haar_random_state(
std::uint64_t n_qubits, std::uint64_t seed = std::random_device()()) {
Kokkos::Random_XorShift64_Pool<> rand_pool(seed);
StateVector<FloatType, Space> state(n_qubits);
StateVector<FloatType> state(n_qubits);
Kokkos::parallel_for(
state._dim, KOKKOS_LAMBDA(std::uint64_t i) {
auto rand_gen = rand_pool.get_state();
Expand Down
19 changes: 9 additions & 10 deletions scaluq/state/state_vector_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

namespace scaluq {

STATE_VECTOR_TEMPLATE(FloatType, Space)
STATE_VECTOR_TEMPLATE(FloatType)
class StateVectorBatched {
std::uint64_t _batch_size;
std::uint64_t _n_qubits;
std::uint64_t _dim;
using ComplexType = Kokkos::complex<FloatType>;

static_assert(std::is_same_v<Space, HostSpace> || std::is_same_v<Space, DefaultSpace>,
"Unsupported execution space tag");
// static_assert(std::is_same_v<Space, HostSpace> || std::is_same_v<Space, DefaultSpace>,
// "Unsupported execution space tag");

public:
Kokkos::View<ComplexType**, Kokkos::LayoutRight> _raw;
Expand All @@ -36,7 +36,7 @@ class StateVectorBatched {

[[nodiscard]] std::uint64_t batch_size() const { return this->_batch_size; }

void set_state_vector(const StateVector<FloatType, Space>& state) {
void set_state_vector(const StateVector<FloatType>& state) {
if (_raw.extent(1) != state._raw.extent(0)) [[unlikely]] {
throw std::runtime_error(
"Error: StateVectorBatched::set_state_vector(const StateVector&): Dimensions of "
Expand All @@ -50,7 +50,7 @@ class StateVectorBatched {
Kokkos::fence();
}

void set_state_vector_at(std::uint64_t batch_id, const StateVector<FloatType, Space>& state) {
void set_state_vector_at(std::uint64_t batch_id, const StateVector<FloatType>& state) {
if (_raw.extent(1) != state._raw.extent(0)) [[unlikely]] {
throw std::runtime_error(
"Error: StateVectorBatched::set_state_vector(std::uint64_t, const StateVector&): "
Expand All @@ -61,8 +61,8 @@ class StateVectorBatched {
Kokkos::fence();
}

[[nodiscard]] StateVector<FloatType, Space> get_state_vector_at(std::uint64_t batch_id) const {
StateVector<FloatType, Space> ret(_n_qubits);
[[nodiscard]] StateVector<FloatType> get_state_vector_at(std::uint64_t batch_id) const {
StateVector<FloatType> ret(_n_qubits);
Kokkos::parallel_for(
_dim, KOKKOS_CLASS_LAMBDA(std::uint64_t i) { ret._raw(i) = _raw(batch_id, i); });
Kokkos::fence();
Expand Down Expand Up @@ -154,8 +154,7 @@ class StateVectorBatched {
Kokkos::Random_XorShift64_Pool<> rand_pool(seed);
StateVectorBatched states(batch_size, n_qubits);
if (set_same_state) {
states.set_state_vector(
StateVector<FloatType, Space>::Haar_random_state(n_qubits, seed));
states.set_state_vector(StateVector<FloatType>::Haar_random_state(n_qubits, seed));
} else {
Kokkos::parallel_for(
Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0, 0}, {states.batch_size(), states.dim()}),
Expand Down Expand Up @@ -273,7 +272,7 @@ class StateVectorBatched {
if (measured_value == 0 || measured_value == 1) {
target_index.push_back(i);
target_value.push_back(measured_value);
} else if (measured_value != StateVector<FloatType, Space>::UNMEASURED) {
} else if (measured_value != StateVector<FloatType>::UNMEASURED) {
throw std::runtime_error(
"Error:StateVectorBatched::get_marginal_probability(const "
"vector<std::uint64_t>&): Invalid qubit state specified. Each qubit state must "
Expand Down

0 comments on commit 736516d

Please sign in to comment.