Skip to content

Commit

Permalink
add ostream difinition of GatePtr
Browse files Browse the repository at this point in the history
  • Loading branch information
gandalfr-KY committed Sep 3, 2024
1 parent a24aac3 commit bccc0e9
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 6 deletions.
14 changes: 12 additions & 2 deletions exe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@ using namespace scaluq;
using namespace std;

void run() {
std::uint64_t n_qubits = 5;
auto state = StateVector::Haar_random_state(n_qubits);
auto x_gate = gate::X(2);
std::cout << x_gate << std::endl;
auto y_gate = gate::Y(2);
std::cout << y_gate << std::endl;
auto swap_gate = gate::Swap(2, 3, {4, 6});
std::cout << swap_gate << "\n\n";

auto prob_gate = gate::Probablistic({0.1, 0.1, 0.8}, {x_gate, y_gate, swap_gate});
std::cout << prob_gate << "\n\n";

auto prob_prob_gate = gate::Probablistic({0.5, 0.5}, {x_gate, prob_gate});
std::cout << prob_prob_gate << "\n\n";
}

int main() {
Expand Down
7 changes: 6 additions & 1 deletion scaluq/gate/gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class SwapGateImpl;
class TwoTargetMatrixGateImpl;
class PauliGateImpl;
class PauliRotationGateImpl;
class ProbablisticGateImpl;

template <GateImpl T>
class GatePtr;
Expand Down Expand Up @@ -80,7 +81,8 @@ enum class GateType {
Swap,
TwoTargetMatrix,
Pauli,
PauliRotation
PauliRotation,
Probablistic
};

template <internal::GateImpl T>
Expand Down Expand Up @@ -119,6 +121,7 @@ constexpr GateType get_gate_type() {
if constexpr (std::is_same_v<T, internal::PauliGateImpl>) return GateType::Pauli;
if constexpr (std::is_same_v<T, internal::PauliRotationGateImpl>)
return GateType::PauliRotation;
if constexpr (std::is_same_v<T, internal::ProbablisticGateImpl>) return GateType::Probablistic;
static_assert("unknown GateImpl");
return GateType::Unknown;
}
Expand Down Expand Up @@ -226,6 +229,8 @@ class GatePtr {
}
return _gate_ptr.get();
}

// 依存関係により、operator<< の定義は gate_factory.hpp に定義
};
} // namespace internal

Expand Down
131 changes: 131 additions & 0 deletions scaluq/gate/gate_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,135 @@ inline Gate Probablistic(const std::vector<double>& distribution,
gate_list);
}
} // namespace gate

template <internal::GateImpl T>
std::ostream& operator<<(std::ostream& os, const internal::GatePtr<T>& obj) {
std::string indent = " ";
if (obj.gate_type() == GateType::Probablistic) {
const auto prob_gate = ProbablisticGate(obj);
const auto distribution = prob_gate->distribution();
const auto gates = prob_gate->gate_list();
os << "Gate Type: Probablistic\n";
for (std::size_t i = 0; i < distribution.size(); ++i) {
std::ostringstream gate_ss;
gate_ss << gates[i];
std::istringstream gate_is(gate_ss.str());
std::string line;
// os << indent << "--------------------\n";
os << indent << "Probability: " << distribution[i] << "\n";
while (std::getline(gate_is, line)) {
os << indent << line << (gate_is.peek() == EOF ? "" : "\n");
}
}
return os;
}
auto targets = internal::mask_to_vector(obj->target_qubit_mask());
auto controls = internal::mask_to_vector(obj->control_qubit_mask());
os << "Gate Type: ";
switch (obj.gate_type()) {
case GateType::I:
os << "I";
break;
case GateType::GlobalPhase:
os << "GlobalPhase";
break;
case GateType::X:
os << "X";
break;
case GateType::Y:
os << "Y";
break;
case GateType::Z:
os << "Z";
break;
case GateType::H:
os << "H";
break;
case GateType::S:
os << "S";
break;
case GateType::Sdag:
os << "Sdag";
break;
case GateType::T:
os << "T";
break;
case GateType::Tdag:
os << "Tdag";
break;
case GateType::SqrtX:
os << "SqrtX";
break;
case GateType::SqrtXdag:
os << "SqrtXdag";
break;
case GateType::SqrtY:
os << "SqrtY";
break;
case GateType::SqrtYdag:
os << "SqrtYdag";
break;
case GateType::P0:
os << "P0";
break;
case GateType::P1:
os << "P1";
break;
case GateType::RX:
os << "RX";
break;
case GateType::RY:
os << "RY";
break;
case GateType::RZ:
os << "RZ";
break;
case GateType::U1:
os << "U1";
break;
case GateType::U2:
os << "U2";
break;
case GateType::U3:
os << "U3";
break;
case GateType::OneTargetMatrix:
os << "OneTargetMatrix";
break;
case GateType::CX:
os << "CX";
break;
case GateType::CZ:
os << "CZ";
break;
case GateType::CCX:
os << "CCX";
break;
case GateType::Swap:
os << "Swap";
break;
case GateType::TwoTargetMatrix:
os << "TwoTargetMatrix";
break;
case GateType::Pauli:
os << "Pauli";
break;
case GateType::PauliRotation:
os << "PauliRotation";
break;
case GateType::Unknown:
default:
os << "Unknown";
break;
}
os << "\n" << indent << "Target Qubits: {";
for (std::uint32_t i = 0; i < targets.size(); ++i)
os << targets[i] << (i == targets.size() - 1 ? "" : ", ");
os << "}\n" << indent << "Control Qubits: {";
for (std::uint32_t i = 0; i < controls.size(); ++i)
os << controls[i] << (i == controls.size() - 1 ? "" : ", ");
os << "}";
return os;
}

} // namespace scaluq
37 changes: 37 additions & 0 deletions scaluq/gate/param_gate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,43 @@ class ParamGatePtr {
}
return _param_gate_ptr.get();
}

friend std::ostream& operator<<(std::ostream& os, const ParamGatePtr& obj) {
if (!obj._param_gate_ptr) {
os << "Gate Type: Null";
return os;
}
auto targets = internal::mask_to_vector(obj->target_qubit_mask());
auto controls = internal::mask_to_vector(obj->control_qubit_mask());
os << "Gate Type: ";
switch (obj.param_gate_type()) {
case ParamGateType::ParamRX:
os << "ParamRX";
break;
case ParamGateType::ParamRY:
os << "ParamRY";
break;
case ParamGateType::ParamRZ:
os << "ParamRZ";
break;
case ParamGateType::ParamPauliRotation:
os << "ParamPauliRotation";
break;
default:
os << "Undefined";
break;
}
os << "\n"
"Target Qubits: {";
for (std::uint32_t i = 0; i < targets.size(); ++i)
os << targets[i] << (i == targets.size() - 1 ? "" : ", ");
os << "}\n"
"Control Qubits: {";
for (std::uint32_t i = 0; i < controls.size(); ++i)
os << controls[i] << (i == controls.size() - 1 ? "" : ", ");
os << "}";
return os;
}
};
} // namespace internal

Expand Down
2 changes: 1 addition & 1 deletion scaluq/state/state_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ std::string StateVector::to_string() const {
}
return tmp;
}(i, _n_qubits)
<< ": " << amp[i] << std::endl;
<< ": " << amp[i] << (i < _dim - 1 ? "\n" : "");
}
return os.str();
}
Expand Down
4 changes: 2 additions & 2 deletions scaluq/state/state_vector_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ std::string StateVectorBatched::to_string() const {
for (std::uint64_t b = 0; b < _batch_size; ++b) {
StateVector tmp(_n_qubits);
os << "--------------------\n";
os << " * Batch_id : " << b << '\n';
os << " * Batch id : " << b << '\n';
os << " * State vector : \n";
for (std::uint64_t i = 0; i < _dim; ++i) {
os <<
Expand All @@ -378,7 +378,7 @@ std::string StateVectorBatched::to_string() const {
}
return tmp;
}(i, _n_qubits)
<< ": " << states_h(b, i) << std::endl;
<< ": " << states_h(b, i) << (b < _batch_size - 1 || i < _dim - 1 ? "\n" : "");
}
}
return os.str();
Expand Down

0 comments on commit bccc0e9

Please sign in to comment.