Skip to content

Commit

Permalink
fix to compile gate/paramgate
Browse files Browse the repository at this point in the history
  • Loading branch information
KowerKoint committed Jan 24, 2025
1 parent 651958a commit 02ca9bf
Show file tree
Hide file tree
Showing 17 changed files with 518 additions and 478 deletions.
14 changes: 7 additions & 7 deletions include/scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ constexpr GateType get_gate_type() {

namespace internal {
// GateBase テンプレートクラス
template <Precision Prec>
class GateBase : public std::enable_shared_from_this<GateBase<Prec>> {
template <Precision _Prec>
class GateBase : public std::enable_shared_from_this<GateBase<_Prec>> {
public:
constexpr static Precision Prec = Prec;
constexpr static Precision Prec = _Prec;
using FloatType = Float<Prec>;
using ComplexType = Complex<Prec>;

Expand Down Expand Up @@ -205,7 +205,7 @@ class GateBase : public std::enable_shared_from_this<GateBase<Prec>> {
}

[[nodiscard]] virtual std::shared_ptr<const GateBase<Prec>> get_inverse() const = 0;
[[nodiscard]] virtual internal::ComplexMatrix get_matrix() const = 0;
[[nodiscard]] virtual ComplexMatrix get_matrix() const = 0;

virtual void update_quantum_state(StateVector<Prec>& state_vector) const = 0;
virtual void update_quantum_state(StateVectorBatched<Prec>& states) const = 0;
Expand All @@ -232,7 +232,7 @@ class GatePtr {
GateType _gate_type;

public:
constexpr Precision Prec = T::Prec;
constexpr static Precision Prec = T::Prec;
using FloatType = Float<Prec>;
using ComplexType = Complex<Prec>;
GatePtr() : _gate_ptr(nullptr), _gate_type(get_gate_type<T, Prec>()) {}
Expand All @@ -241,7 +241,7 @@ class GatePtr {
if constexpr (std::is_same_v<T, U>) {
_gate_type = get_gate_type<T, Prec>();
_gate_ptr = gate_ptr;
} else if constexpr (std::is_same_v<T, internal::GateBase<Prec>>) {
} else if constexpr (std::is_same_v<T, GateBase<Prec>>) {
// upcast
_gate_type = get_gate_type<U, Prec>();
_gate_ptr = std::static_pointer_cast<const T>(gate_ptr);
Expand All @@ -258,7 +258,7 @@ class GatePtr {
if constexpr (std::is_same_v<T, U>) {
_gate_type = gate._gate_type;
_gate_ptr = gate._gate_ptr;
} else if constexpr (std::is_same_v<T, internal::GateBase<Prec>>) {
} else if constexpr (std::is_same_v<T, GateBase<Prec>>) {
// upcast
_gate_type = gate._gate_type;
_gate_ptr = std::static_pointer_cast<const T>(gate._gate_ptr);
Expand Down
30 changes: 15 additions & 15 deletions include/scaluq/gate/gate_pauli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class PauliGateImpl : public GateBase<Prec> {
std::shared_ptr<const GateBase<Prec>> get_inverse() const override {
return this->shared_from_this();
}
internal::ComplexMatrix get_matrix() const override { return this->_pauli.get_matrix(); }
ComplexMatrix get_matrix() const override { return this->_pauli.get_matrix(); }

void update_quantum_state(StateVector<Prec>& state_vector) const override;
void update_quantum_state(StateVectorBatched<Prec>& states) const override;
Expand Down Expand Up @@ -59,7 +59,7 @@ class PauliRotationGateImpl : public GateBase<Prec> {
this->_control_mask, _pauli, -_angle);
}

internal::ComplexMatrix get_matrix() const override;
ComplexMatrix get_matrix() const override;

void update_quantum_state(StateVector<Prec>& state_vector) const override;
void update_quantum_state(StateVectorBatched<Prec>& states) const override;
Expand All @@ -81,35 +81,35 @@ template <Precision Prec>
using PauliRotationGate = internal::GatePtr<internal::PauliRotationGateImpl<Prec>>;

namespace internal {
#define DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE(Type) \
#define DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION(Prec) \
template <> \
inline std::shared_ptr<const PauliGateImpl<Type>> get_from_json(const Json& j) { \
inline std::shared_ptr<const PauliGateImpl<Prec>> get_from_json(const Json& j) { \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
auto pauli = j.at("pauli").get<PauliOperator<Type>>(); \
return std::make_shared<const PauliGateImpl<Type>>(vector_to_mask(controls), pauli); \
auto pauli = j.at("pauli").get<PauliOperator<Prec>>(); \
return std::make_shared<const PauliGateImpl<Prec>>(vector_to_mask(controls), pauli); \
} \
template <> \
inline std::shared_ptr<const PauliRotationGateImpl<Type>> get_from_json(const Json& j) { \
inline std::shared_ptr<const PauliRotationGateImpl<Prec>> get_from_json(const Json& j) { \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
auto pauli = j.at("pauli").get<PauliOperator<Type>>(); \
auto angle = j.at("angle").get<Type>(); \
return std::make_shared<const PauliRotationGateImpl<Type>>( \
auto pauli = j.at("pauli").get<PauliOperator<Prec>>(); \
auto angle = j.at("angle").get<double>(); \
return std::make_shared<const PauliRotationGateImpl<Prec>>( \
vector_to_mask(controls), pauli, angle); \
}

#ifdef SCALUQ_FLOAT16
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE(F16)
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION(Precision::F16)
#endif
#ifdef SCALUQ_FLOAT32
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE(F32)
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION(Precision::F32)
#endif
#ifdef SCALUQ_FLOAT64
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE(F64)
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION(Precision::F64)
#endif
#ifdef SCALUQ_BFLOAT16
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE(BF16)
DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION(Precision::BF16)
#endif
#undef DECLARE_GET_FROM_JSON_PAULIGATE_WITH_TYPE
#undef DECLARE_GET_FROM_JSON_PAULIGATE_WITH_PRECISION

} // namespace internal

Expand Down
20 changes: 10 additions & 10 deletions include/scaluq/gate/gate_probablistic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ProbablisticGateImpl : public GateBase<Prec> {
}

std::shared_ptr<const GateBase<Prec>> get_inverse() const override;
internal::ComplexMatrix get_matrix() const override {
ComplexMatrix get_matrix() const override {
throw std::runtime_error(
"ProbablisticGateImpl::get_matrix(): This function must not be used in "
"ProbablisticGateImpl.");
Expand All @@ -76,27 +76,27 @@ using ProbablisticGate = internal::GatePtr<internal::ProbablisticGateImpl<Prec>>

namespace internal {

#define DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE(Type) \
#define DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION(Prec) \
template <> \
inline std::shared_ptr<const ProbablisticGateImpl<Type>> get_from_json(const Json& j) { \
inline std::shared_ptr<const ProbablisticGateImpl<Prec>> get_from_json(const Json& j) { \
auto distribution = j.at("distribution").get<std::vector<double>>(); \
auto gate_list = j.at("gate_list").get<std::vector<Gate<Type>>>(); \
return std::make_shared<const ProbablisticGateImpl<Type>>(distribution, gate_list); \
auto gate_list = j.at("gate_list").get<std::vector<Gate<Prec>>>(); \
return std::make_shared<const ProbablisticGateImpl<Prec>>(distribution, gate_list); \
}

#ifdef SCALUQ_FLOAT16
DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE(F16)
DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION(Precision::F16)
#endif
#ifdef SCALUQ_FLOAT32
DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE(F32)
DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION(Precision::F32)
#endif
#ifdef SCALUQ_FLOAT64
DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE(F64)
DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION(Precision::F64)
#endif
#ifdef SCALUQ_BFLOAT16
DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE(BF16)
DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION(Precision::BF16)
#endif
#undef DECLARE_GET_FROM_JSON_PROBGATE_WITH_TYPE
#undef DECLARE_GET_FROM_JSON_PROBGATE_WITH_PRECISION

} // namespace internal

Expand Down
Loading

0 comments on commit 02ca9bf

Please sign in to comment.