avoid batch interleaved path for now
FMarno committed Mar 7, 2024
1 parent 037fa6f commit 2e7f777
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/portfft/dispatcher/workgroup_dispatcher.hpp
@@ -40,9 +40,11 @@ namespace detail {
  * @param is_batch_interleaved is the input data layout batch interleaved
  * @param workgroup_size The size of the work-group. Must be divisible by 2.
  */
-PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool is_batch_interleaved,
-                                                                    Idx workgroup_size) noexcept {
-  return is_batch_interleaved ? workgroup_size / 2 : 1;
+PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool /*is_batch_interleaved*/,
+                                                                    Idx /*workgroup_size*/) noexcept {
+  // TODO reenable when tests are passing
+  // return is_batch_interleaved ? workgroup_size / 2 : 1;
+  return 1;
 }
 
 /**
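For context, here is a standalone sketch (not part of this commit) of what the helper returns before and after the change. The Idx alias and the example work-group size of 4 sub-groups of 32 work-items are assumptions made purely for illustration.

// Standalone illustration only -- not part of the diff. `Idx` and the example
// work-group size are assumed values.
#include <cstdint>
#include <iostream>

using Idx = std::int32_t;

// Behaviour before this commit (the formula now commented out above).
constexpr Idx batches_in_local_mem_before(bool is_batch_interleaved, Idx workgroup_size) noexcept {
  return is_batch_interleaved ? workgroup_size / 2 : 1;
}

// Behaviour after this commit: always a single batch in local memory.
constexpr Idx batches_in_local_mem_after(bool /*is_batch_interleaved*/, Idx /*workgroup_size*/) noexcept {
  return 1;
}

int main() {
  constexpr Idx workgroup_size = 4 * 32;
  std::cout << batches_in_local_mem_before(true, workgroup_size) << '\n';  // 64
  std::cout << batches_in_local_mem_after(true, workgroup_size) << '\n';   // 1
}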
@@ -110,8 +112,9 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
   const IdxGlobal input_distance = kh.get_specialization_constant<detail::SpecConstInputDistance>();
   const IdxGlobal output_distance = kh.get_specialization_constant<detail::SpecConstOutputDistance>();
 
-  const bool is_input_batch_interleaved = input_stride == n_transforms && input_distance == 1;
-  const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
+  // TODO reenable when tests are passing
+  const bool is_input_batch_interleaved = false;  // input_stride == n_transforms && input_distance == 1;
+  const bool is_input_packed = input_stride == 1 && input_distance == fft_size;
 
   global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms);
   Idx num_workgroups = static_cast<Idx>(global_data.it.get_group_range(0));
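As a side note (not from the commit), the sketch below shows how the stride/distance predicates in this hunk map onto the two input layouts. The values fft_size = 8 and n_transforms = 4 and the flat_index helper are assumptions used only to demonstrate the checks.

// Standalone illustration only -- not part of the diff.
#include <cstdint>

using IdxGlobal = std::int64_t;

// Flat offset of element `i` of transform `batch` for a layout given by (stride, distance).
constexpr IdxGlobal flat_index(IdxGlobal stride, IdxGlobal distance, IdxGlobal batch, IdxGlobal i) {
  return batch * distance + i * stride;
}

constexpr IdxGlobal fft_size = 8;
constexpr IdxGlobal n_transforms = 4;

// Packed: each transform is contiguous, consecutive transforms follow each other.
constexpr IdxGlobal packed_stride = 1, packed_distance = fft_size;
static_assert(packed_stride == 1 && packed_distance == fft_size);      // matches is_input_packed
static_assert(flat_index(packed_stride, packed_distance, 1, 0) == 8);  // transform 1 starts 8 elements in

// Batch interleaved: element i of every transform is stored contiguously.
constexpr IdxGlobal bi_stride = n_transforms, bi_distance = 1;
static_assert(bi_stride == n_transforms && bi_distance == 1);          // matches is_input_batch_interleaved
static_assert(flat_index(bi_stride, bi_distance, 1, 0) == 1);          // transform 1's first element is at offset 1

int main() { return 0; }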
@@ -280,8 +283,8 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<SubgroupSize
     PORTFFT_LOG_FUNCTION_ENTRY();
     auto& kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels.at(0)
                                                                 : dimension_data.backward_kernels.at(0);
-    Idx num_batches_in_local_mem =
-        input_layout == layout::BATCH_INTERLEAVED ? kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2 : 1;
+    Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup(
+        input_layout == layout::BATCH_INTERLEAVED, kernel_data.used_sg_size * PORTFFT_SGS_IN_WG);
     constexpr detail::memory Mem = std::is_pointer_v<TOut> ? detail::memory::USM : detail::memory::BUFFER;
     Scalar* twiddles = kernel_data.twiddles_forward.get();
     std::size_t local_elements =
@@ -355,8 +358,8 @@ struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struc
     // working memory + twiddles for subgroup impl for the two sizes
     Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup(
         input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG);
-    return detail::pad_local(static_cast<std::size_t>(2 * num_batches_in_local_mem) * length,
-                             bank_lines_per_pad_wg(2 * static_cast<std::size_t>(sizeof(Scalar)) * m)) +
+    const auto bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast<std::size_t>(sizeof(Scalar)) * m);
+    return detail::pad_local(static_cast<std::size_t>(2 * num_batches_in_local_mem) * length, bank_lines_per_pad) +
            2 * (m + n);
   }
 };
