From 66dd2d4fa83552781dc2a2d1cb3d15b7a0494edc Mon Sep 17 00:00:00 2001 From: Glacialte Date: Thu, 11 Jan 2024 09:43:36 +0000 Subject: [PATCH] use single loop in fusedswap_gate(update_ops_npair_qubit.cpp) --- qulacs/gate/update_ops_npair_qubit.cpp | 32 ++++++++------------------ 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/qulacs/gate/update_ops_npair_qubit.cpp b/qulacs/gate/update_ops_npair_qubit.cpp index 303fe0d9..601d7e0a 100644 --- a/qulacs/gate/update_ops_npair_qubit.cpp +++ b/qulacs/gate/update_ops_npair_qubit.cpp @@ -16,32 +16,20 @@ void fusedswap_gate(UINT target_qubit_index_0, lower_index = std::min(target_qubit_index_0, target_qubit_index_1); assert(upper_index > (lower_index + block_size - 1)); assert(n_qubits > (upper_index + block_size - 1)); - const UINT kblk_dim = 1ULL << (n_qubits - upper_index); - const UINT jblk_dim = 1ULL << (upper_index - lower_index); - const UINT iblk_dim = 1ULL << lower_index; const UINT mask_block = (1 << block_size) - 1; auto amplitudes = state.amplitudes_raw(); + const UINT kblk_mask = mask_block << upper_index; + const UINT jblk_mask = mask_block << lower_index; + const UINT else_mask = (1 << n_qubits) - 1 - kblk_mask - jblk_mask; Kokkos::parallel_for( - kblk_dim, KOKKOS_LAMBDA(const UINT& kblk) { - const UINT kblk_masked = kblk & mask_block; - const UINT kblk_head = kblk - kblk_masked; - const UINT jblk_start = kblk_masked + 1; - - Kokkos::parallel_for( - jblk_dim, KOKKOS_LAMBDA(const UINT& jblk) { - const UINT jblk_masked = jblk & mask_block; - const UINT jblk_head = jblk - jblk_masked; - if (jblk_masked < jblk_start) return; - - UINT si = (kblk << upper_index) + (jblk << lower_index); - UINT ti = ((kblk_head + jblk_masked) << upper_index) + - ((jblk_head + kblk_masked) << lower_index); - Kokkos::parallel_for( - iblk_dim, KOKKOS_LAMBDA(const UINT& i) { - Kokkos::Experimental::swap(amplitudes[si + i], amplitudes[ti + i]); - }); - }); + 1 << n_qubits, KOKKOS_LAMBDA(const UINT& i) { + const UINT kblk = (i & kblk_mask) >> upper_index; + const UINT jblk = (i & jblk_mask) >> lower_index; + if (jblk > kblk) { + const UINT index = (i & else_mask) | jblk << upper_index | kblk << lower_index; + Kokkos::Experimental::swap(amplitudes[i], amplitudes[index]); + } }); } } // namespace qulacs