Skip to content

Commit

Permalink
gate_standard
Browse files Browse the repository at this point in the history
  • Loading branch information
gandalfr-KY committed Oct 4, 2024
1 parent 736516d commit 5228ed3
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 108 deletions.
45 changes: 24 additions & 21 deletions exe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@ namespace internal {

enum class GateType { Unknown, X };

template <std::floating_point FloatType>
class XGateImpl;

template <typename T>
constexpr GateType get_gate_type() {
if constexpr (std::is_same_v<T, XGateImpl>) {
if constexpr (std::is_same_v<T, XGateImpl<float>> || std::is_same_v<T, XGateImpl<double>>) {
return GateType::X;
} else {
static_assert(lazy_false_v<T>, "unknown GateImpl");
}
}

class GateBase : public std::enable_shared_from_this<GateBase> {
// GateBase テンプレートクラス
template <std::floating_point _FloatType>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType>> {
public:
using FloatType = _FloatType;

protected:
std::uint64_t _target_mask, _control_mask;

template <std::floating_point FloatType>
void check_qubit_mask_within_bounds(const StateVector<FloatType>& state_vector) const {
std::uint64_t full_mask = (1ULL << state_vector.n_qubits()) - 1;
if ((_target_mask | _control_mask) > full_mask) [[unlikely]] {
Expand Down Expand Up @@ -80,17 +85,18 @@ class GateBase : public std::enable_shared_from_this<GateBase> {
return _target_mask | _control_mask;
}

virtual void update_quantum_state(StateVector<double>& state_vector) const = 0;
virtual void update_quantum_state(StateVector<float>& state_vector) const = 0;
virtual void update_quantum_state(StateVector<FloatType>& state_vector) const = 0;

[[nodiscard]] virtual std::string to_string(const std::string& indent = "") const = 0;
};

template <typename T>
concept GateImpl = std::derived_from<T, GateBase>;
concept GateImpl = std::derived_from<T, GateBase<typename T::FloatType>>;

template <GateImpl T>
class GatePtr {
using FloatType = T::FloatType;

private:
std::shared_ptr<const T> _gate_ptr;
GateType _gate_type;
Expand Down Expand Up @@ -149,7 +155,8 @@ class GatePtr {
}
};

using Gate = GatePtr<GateBase>;
template <std::floating_point FloatType>
using Gate = GatePtr<GateBase<FloatType>>;

template <std::floating_point FloatType>
void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector<FloatType>& state) {
Expand All @@ -162,31 +169,26 @@ void x_gate(std::uint64_t target_mask, std::uint64_t control_mask, StateVector<F
Kokkos::fence();
}

class XGateImpl : public GateBase {
template <std::floating_point FloatType>
class XGateImpl : public GateBase<FloatType> {
public:
using GateBase::GateBase;
using GateBase<FloatType>::GateBase;

void update_quantum_state(StateVector<double>& state_vector) const override {
this->check_qubit_mask_within_bounds(state_vector);
x_gate(this->_target_mask, this->_control_mask, state_vector);
}

void update_quantum_state(StateVector<float>& state_vector) const override {
void update_quantum_state(StateVector<FloatType>& state_vector) const override {
this->check_qubit_mask_within_bounds(state_vector);
x_gate(this->_target_mask, this->_control_mask, state_vector);
}

std::string to_string(const std::string& indent = "") const override {
std::ostringstream ss;
ss << indent << "XGate";
return ss.str();
}
};

class GateFactory {
public:
template <GateImpl T, typename... Args>
static internal::Gate create_gate(Args... args) {
static internal::Gate<typename T::FloatType> create_gate(Args... args) {
return {std::make_shared<const T>(args...)};
}
};
Expand All @@ -195,9 +197,10 @@ class GateFactory {

namespace gate {

inline internal::Gate X(std::uint64_t target,
const std::vector<std::uint64_t>& control_qubits = {}) {
return internal::GateFactory::create_gate<internal::XGateImpl>(
template <std::floating_point FloatType>
inline internal::Gate<FloatType> X(std::uint64_t target,
const std::vector<std::uint64_t>& control_qubits = {}) {
return internal::GateFactory::create_gate<internal::XGateImpl<FloatType>>(
internal::vector_to_mask({target}), internal::vector_to_mask(control_qubits));
}

Expand All @@ -211,7 +214,7 @@ int main() {
std::uint64_t n_qubits = 3;
scaluq::StateVector<double> state(n_qubits);
state.load({0, 1, 2, 3, 4, 5, 6, 7});
auto x_gate = scaluq::gate::X(1, {0, 2});
auto x_gate = scaluq::gate::X<double>(1, {0, 2});
x_gate->update_quantum_state(state);

std::cout << state << std::endl;
Expand Down
50 changes: 41 additions & 9 deletions scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,73 @@
namespace scaluq {
namespace internal {
// forward declarations

template <std::floating_point FloatType>
class GateBase;

template <typename T>
concept GateImpl = std::derived_from<T, GateBase>;
concept GateImpl = std::derived_from<T, GateBase<typename T::FloatType>>;

class IGateImpl;
template <std::floating_point FloatType>
class GlobalPhaseGateImpl;
template <std::floating_point FloatType>
class XGateImpl;
template <std::floating_point FloatType>
class YGateImpl;
template <std::floating_point FloatType>
class ZGateImpl;
template <std::floating_point FloatType>
class HGateImpl;
template <std::floating_point FloatType>
class SGateImpl;
template <std::floating_point FloatType>
class SdagGateImpl;
template <std::floating_point FloatType>
class TGateImpl;
template <std::floating_point FloatType>
class TdagGateImpl;
template <std::floating_point FloatType>
class SqrtXGateImpl;
template <std::floating_point FloatType>
class SqrtXdagGateImpl;
template <std::floating_point FloatType>
class SqrtYGateImpl;
template <std::floating_point FloatType>
class SqrtYdagGateImpl;
template <std::floating_point FloatType>
class P0GateImpl;
template <std::floating_point FloatType>
class P1GateImpl;
template <std::floating_point FloatType>
class RXGateImpl;
template <std::floating_point FloatType>
class RYGateImpl;
template <std::floating_point FloatType>
class RZGateImpl;
template <std::floating_point FloatType>
class U1GateImpl;
template <std::floating_point FloatType>
class U2GateImpl;
template <std::floating_point FloatType>
class U3GateImpl;
template <std::floating_point FloatType>
class OneTargetMatrixGateImpl;
template <std::floating_point FloatType>
class SwapGateImpl;
template <std::floating_point FloatType>
class TwoTargetMatrixGateImpl;
template <std::floating_point FloatType>
class PauliGateImpl;
template <std::floating_point FloatType>
class PauliRotationGateImpl;
template <std::floating_point FloatType>
class ProbablisticGateImpl;

template <GateImpl T>
class GatePtr;
} // namespace internal
using Gate = internal::GatePtr<internal::GateBase>;
template <std::floating_point FloatType>
using Gate = GatePtr<GateBase<FloatType>>;

enum class GateType {
Unknown,
Expand Down Expand Up @@ -145,10 +174,16 @@ constexpr GateType get_gate_type() {
}

namespace internal {
class GateBase : public std::enable_shared_from_this<GateBase> {
// GateBase テンプレートクラス
template <std::floating_point _FloatType>
class GateBase : public std::enable_shared_from_this<GateBase<_FloatType>> {
public:
using FloatType = _FloatType;

protected:
std::uint64_t _target_mask, _control_mask;
void check_qubit_mask_within_bounds(const StateVector& state_vector) const {

void check_qubit_mask_within_bounds(const StateVector<FloatType>& state_vector) const {
std::uint64_t full_mask = (1ULL << state_vector.n_qubits()) - 1;
if ((_target_mask | _control_mask) > full_mask) [[unlikely]] {
throw std::runtime_error(
Expand Down Expand Up @@ -198,10 +233,7 @@ class GateBase : public std::enable_shared_from_this<GateBase> {
return _target_mask | _control_mask;
}

[[nodiscard]] virtual Gate get_inverse() const = 0;
[[nodiscard]] virtual internal::ComplexMatrix get_matrix() const = 0;

virtual void update_quantum_state(StateVector& state_vector) const = 0;
virtual void update_quantum_state(StateVector<FloatType>& state_vector) const = 0;

[[nodiscard]] virtual std::string to_string(const std::string& indent = "") const = 0;
};
Expand Down
Loading

0 comments on commit 5228ed3

Please sign in to comment.