Skip to content

Commit

Permalink
fix to compile both on cpu/cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
KowerKoint committed Jan 24, 2025
1 parent a2f7a01 commit 76dac86
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 46 deletions.
2 changes: 1 addition & 1 deletion include/scaluq/gate/gate_pauli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ namespace internal {
auto pauli = j.at("pauli").get<PauliOperator<Prec>>(); \
auto angle = j.at("angle").get<double>(); \
return std::make_shared<const PauliRotationGateImpl<Prec>>( \
vector_to_mask(controls), pauli, angle); \
vector_to_mask(controls), pauli, static_cast<Float<Prec>>(angle)); \
}

#ifdef SCALUQ_FLOAT16
Expand Down
78 changes: 42 additions & 36 deletions include/scaluq/gate/gate_standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ DECLARE_GET_FROM_JSON_IGATE_WITH_PRECISION(Precision::BF16)
inline std::shared_ptr<const GlobalPhaseGateImpl<Prec>> get_from_json(const Json& j) { \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double phase = j.at("phase").get<double>(); \
return std::make_shared<const GlobalPhaseGateImpl<Prec>>(vector_to_mask(controls), phase); \
return std::make_shared<const GlobalPhaseGateImpl<Prec>>(vector_to_mask(controls), \
static_cast<Float<Prec>>(phase)); \
}
#ifdef SCALUQ_FLOAT16
DECLARE_GET_FROM_JSON_GLOBALPHASEGATE_WITH_PRECISION(Precision::F16)
Expand Down Expand Up @@ -719,14 +720,14 @@ DECALRE_GET_FROM_JSON_EACH_SINGLETARGETGATE_WITH_PRECISION(Precision::BF16)
#undef DECLARE_GET_FROM_JSON_SINGLETARGETGATE_WITH_PRECISION
#undef DECLARE_GET_FROM_JSON_EACH_SINGLETARGETGATE_WITH_PRECISION

#define DECLARE_GET_FROM_JSON_RGATE_WITH_PRECISION(Impl, Prec) \
template <> \
inline std::shared_ptr<const Impl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double angle = j.at("angle").get<double>(); \
return std::make_shared<const Impl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), angle); \
#define DECLARE_GET_FROM_JSON_RGATE_WITH_PRECISION(Impl, Prec) \
template <> \
inline std::shared_ptr<const Impl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double angle = j.at("angle").get<double>(); \
return std::make_shared<const Impl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), static_cast<Float<Prec>>(angle)); \
}
#define DECLARE_GET_FROM_JSON_EACH_RGATE_WITH_PRECISION(Prec) \
DECLARE_GET_FROM_JSON_RGATE_WITH_PRECISION(RXGateImpl, Prec) \
Expand All @@ -747,33 +748,38 @@ DECLARE_GET_FROM_JSON_EACH_RGATE_WITH_PRECISION(Precision::BF16)
#undef DECLARE_GET_FROM_JSON_RGATE
#undef DECLARE_GET_FROM_JSON_EACH_RGATE_WITH_PRECISION

#define DECLARE_GET_FROM_JSON_UGATE_WITH_PRECISION(Prec) \
template <> \
inline std::shared_ptr<const U1GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
return std::make_shared<const U1GateImpl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), theta); \
} \
template <> \
inline std::shared_ptr<const U2GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
double phi = j.at("phi").get<double>(); \
return std::make_shared<const U2GateImpl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), theta, phi); \
} \
template <> \
inline std::shared_ptr<const U3GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
double phi = j.at("phi").get<double>(); \
double lambda = j.at("lambda").get<double>(); \
return std::make_shared<const U3GateImpl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), theta, phi, lambda); \
#define DECLARE_GET_FROM_JSON_UGATE_WITH_PRECISION(Prec) \
template <> \
inline std::shared_ptr<const U1GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
return std::make_shared<const U1GateImpl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), static_cast<Float<Prec>>(theta)); \
} \
template <> \
inline std::shared_ptr<const U2GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
double phi = j.at("phi").get<double>(); \
return std::make_shared<const U2GateImpl<Prec>>(vector_to_mask(targets), \
vector_to_mask(controls), \
static_cast<Float<Prec>>(theta), \
static_cast<Float<Prec>>(phi)); \
} \
template <> \
inline std::shared_ptr<const U3GateImpl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
double theta = j.at("theta").get<double>(); \
double phi = j.at("phi").get<double>(); \
double lambda = j.at("lambda").get<double>(); \
return std::make_shared<const U3GateImpl<Prec>>(vector_to_mask(targets), \
vector_to_mask(controls), \
static_cast<Float<Prec>>(theta), \
static_cast<Float<Prec>>(phi), \
static_cast<Float<Prec>>(lambda)); \
}
#ifdef SCALUQ_FLOAT16
DECLARE_GET_FROM_JSON_UGATE_WITH_PRECISION(Precision::F16)
Expand Down
2 changes: 1 addition & 1 deletion include/scaluq/gate/param_gate_pauli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace internal {
auto pauli = j.at("pauli").get<PauliOperator<Prec>>(); \
auto param_coef = j.at("param_coef").get<double>(); \
return std::make_shared<const ParamPauliRotationGateImpl<Prec>>( \
vector_to_mask(controls), pauli, param_coef); \
vector_to_mask(controls), pauli, static_cast<Float<Prec>>(param_coef)); \
}

#ifdef SCALUQ_FLOAT16
Expand Down
17 changes: 9 additions & 8 deletions include/scaluq/gate/param_gate_standard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ using ParamRZGate = internal::ParamGatePtr<internal::ParamRZGateImpl<Prec>>;

namespace internal {

#define DECLARE_GET_FROM_JSON_PARAM_RGATE_WITH_PRECISION(Impl, Prec) \
template <> \
inline std::shared_ptr<const Impl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
auto param_coef = j.at("param_coef").get<double>(); \
return std::make_shared<const Impl<Prec>>( \
vector_to_mask(targets), vector_to_mask(controls), param_coef); \
#define DECLARE_GET_FROM_JSON_PARAM_RGATE_WITH_PRECISION(Impl, Prec) \
template <> \
inline std::shared_ptr<const Impl<Prec>> get_from_json(const Json& j) { \
auto targets = j.at("target").get<std::vector<std::uint64_t>>(); \
auto controls = j.at("control").get<std::vector<std::uint64_t>>(); \
auto param_coef = j.at("param_coef").get<double>(); \
return std::make_shared<const Impl<Prec>>(vector_to_mask(targets), \
vector_to_mask(controls), \
static_cast<Float<Prec>>(param_coef)); \
}

#define DECLARE_GET_FROM_JSON_EACH_PARAM_RGATE_WITH_PRECISION(Prec) \
Expand Down

0 comments on commit 76dac86

Please sign in to comment.