Skip to content

Commit

Permalink
[SYCL] Enable mapping of group load/store functions to SPIRV built-in…
Browse files Browse the repository at this point in the history
…s for local address space (#16653)

Extension:
https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html

Currently these built-ins for local address space are not supported by
cpu/fpga backends, so introduce undocumented `native_local_block_io`
property which allows to enable mapping to those built-ins. If this
property is not provided then implementation falls back to naive
approach.
  • Loading branch information
againull authored Jan 24, 2025
1 parent 4d3d4e6 commit 5edcf74
Show file tree
Hide file tree
Showing 12 changed files with 2,996 additions and 1,360 deletions.
41 changes: 41 additions & 0 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,47 @@ template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL void
__spirv_SubgroupBlockWriteINTEL(__attribute__((opencl_global)) uint64_t *Ptr,
dataT Data) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupBlockReadINTEL(const __attribute__((opencl_local))
uint8_t *Ptr) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL void
__spirv_SubgroupBlockWriteINTEL(__attribute__((opencl_local)) uint8_t *Ptr,
dataT Data) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupBlockReadINTEL(const __attribute__((opencl_local))
uint16_t *Ptr) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL void
__spirv_SubgroupBlockWriteINTEL(__attribute__((opencl_local)) uint16_t *Ptr,
dataT Data) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupBlockReadINTEL(const __attribute__((opencl_local))
uint32_t *Ptr) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL void
__spirv_SubgroupBlockWriteINTEL(__attribute__((opencl_local)) uint32_t *Ptr,
dataT Data) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupBlockReadINTEL(const __attribute__((opencl_local))
uint64_t *Ptr) noexcept;

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL void
__spirv_SubgroupBlockWriteINTEL(__attribute__((opencl_local)) uint64_t *Ptr,
dataT Data) noexcept;

template <int W, int rW>
extern __DPCPP_SYCL_EXTERNAL sycl::detail::ap_int<rW>
__spirv_FixedSqrtINTEL(sycl::detail::ap_int<W> a, bool S, int32_t I, int32_t rI,
Expand Down
112 changes: 90 additions & 22 deletions sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ struct naive_key : detail::compile_time_property_key<detail::PropKind::Naive> {
using value_t = property_value<naive_key>;
};
inline constexpr naive_key::value_t naive;

struct native_local_block_io_key
: detail::compile_time_property_key<detail::PropKind::NativeLocalBlockIO> {
using value_t = property_value<native_local_block_io_key>;
};
inline constexpr native_local_block_io_key::value_t native_local_block_io;

using namespace sycl::detail;
} // namespace detail

Expand Down Expand Up @@ -154,7 +161,6 @@ template <typename BlockInfoTy> struct BlockTypeInfo;
template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>;
static_assert(BlockInfoTy::has_builtin);

using block_type = detail::fixed_width_unsigned<BlockInfoTy::block_size>;

Expand All @@ -163,15 +169,23 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
typename std::iterator_traits<IteratorT>::reference>>,
std::add_const_t<block_type>, block_type>;

using block_pointer_type = typename detail::DecoratedType<
block_pointer_elem_type, access::address_space::global_space>::type *;
static constexpr auto deduced_address_space =
detail::deduce_AS<std::remove_cv_t<IteratorT>>::value;

using block_pointer_type =
typename detail::DecoratedType<block_pointer_elem_type,
deduced_address_space>::type *;

using block_op_type = std::conditional_t<
BlockInfoTy::num_blocks == 1, block_type,
detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
};

// Returns either a pointer suitable to use in a block read/write builtin or
// nullptr if some legality conditions aren't satisfied.
// Returns either a pointer decorated with the deduced address space, suitable
// to use in a block read/write builtin, or nullptr if some legality conditions
// aren't satisfied. If deduced address space is generic then returned pointer
// will have generic address space and has to be dynamically casted to global or
// local space before using in a builtin.
template <int RequiredAlign, std::size_t ElementsPerWorkItem,
typename IteratorT, typename Properties>
auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
Expand Down Expand Up @@ -211,16 +225,17 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
bool is_aligned = alignof(value_type) >= RequiredAlign ||
reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;

constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
using block_pointer_type =
typename BlockTypeInfo<BlkInfo>::block_pointer_type;
if constexpr (AS == access::address_space::global_space) {

static constexpr auto deduced_address_space =
BlockTypeInfo<BlkInfo>::deduced_address_space;
if constexpr (deduced_address_space ==
access::address_space::generic_space ||
deduced_address_space ==
access::address_space::global_space ||
deduced_address_space == access::address_space::local_space) {
return is_aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
} else if constexpr (AS == access::address_space::generic_space) {
return is_aligned ? reinterpret_cast<block_pointer_type>(
detail::dynamic_address_cast<
access::address_space::global_space>(iter))
: nullptr;
} else {
return nullptr;
}
Expand Down Expand Up @@ -261,11 +276,37 @@ group_load(Group g, InputIteratorT in_ptr,
// Do optimized load.
using value_type = remove_decoration_t<
typename std::iterator_traits<InputIteratorT>::value_type>;

auto load = __spirv_SubgroupBlockReadINTEL<
typename detail::BlockTypeInfo<detail::BlockInfo<
InputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
ptr);
using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
using block_op_type = typename block_info::block_op_type;

if constexpr (deduced_address_space ==
access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
return group_load(g, in_ptr, out, use_naive{});

block_op_type load;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr = detail::dynamic_address_cast<
access::address_space::local_space>(ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
else
return group_load(g, in_ptr, out, use_naive{});
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
} else {
return group_load(g, in_ptr, out, use_naive{});
}
} else {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
}

// TODO: accessor_iterator's value_type is weird, so we need
// `std::remove_const_t` below:
Expand Down Expand Up @@ -331,6 +372,16 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
return group_store(g, in, out_ptr, use_naive{});

if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
if constexpr (deduced_address_space ==
access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
return group_store(g, in, out_ptr, use_naive{});

// Do optimized store.
std::remove_const_t<remove_decoration_t<
typename std::iterator_traits<OutputIteratorT>::value_type>>
Expand All @@ -341,11 +392,28 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
values[i] = in[i];
}

__spirv_SubgroupBlockWriteINTEL(
ptr,
sycl::bit_cast<typename detail::BlockTypeInfo<detail::BlockInfo<
OutputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
values));
using block_op_type = typename block_info::block_op_type;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr = detail::dynamic_address_cast<
access::address_space::local_space>(ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
__spirv_SubgroupBlockWriteINTEL(
local_ptr, sycl::bit_cast<block_op_type>(values));
else
return group_store(g, in, out_ptr, use_naive{});
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
__spirv_SubgroupBlockWriteINTEL(
global_ptr, sycl::bit_cast<block_op_type>(values));
} else {
return group_store(g, in, out_ptr, use_naive{});
}
} else {
__spirv_SubgroupBlockWriteINTEL(ptr,
sycl::bit_cast<block_op_type>(values));
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion sycl/include/sycl/ext/oneapi/properties/property.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ enum PropKind : uint32_t {
WorkGroupScratchSize = 79,
Restrict = 80,
EventMode = 81,
NativeLocalBlockIO = 82,
// PropKindSize must always be the last value.
PropKindSize = 82,
PropKindSize = 83,
};

template <typename PropertyT> struct PropertyToKind {
Expand Down
79 changes: 66 additions & 13 deletions sycl/test-e2e/GroupAlgorithm/load_store/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

#include <numeric>

int main() {
using namespace sycl;
using namespace sycl;

template <access::address_space addr_space> int test(queue &q) {
namespace sycl_exp = sycl::ext::oneapi::experimental;

constexpr std::size_t wg_size = 32;
Expand All @@ -16,8 +17,6 @@ int main() {
constexpr std::size_t elems_per_wi = 4;
constexpr std::size_t n = global_size * elems_per_wi;

queue q;

buffer<int, 1> input_buf{n};

{
Expand All @@ -42,8 +41,10 @@ int main() {
accessor store_blocked{store_blocked_buf, cgh};
accessor store_striped{store_striped_buf, cgh};

local_accessor<int, 1> local_acc{wg_size * elems_per_wi, cgh};
cgh.parallel_for(nd_range<1>{global_size, wg_size}, [=](nd_item<1> ndi) {
auto gid = ndi.get_global_id(0);
auto lid = ndi.get_local_id(0);
auto g = ndi.get_group();
auto offset = g.get_group_id(0) * g.get_local_range(0) * elems_per_wi;

Expand All @@ -52,31 +53,76 @@ int main() {
auto blocked = sycl_exp::properties{sycl_exp::data_placement_blocked};
auto striped = sycl_exp::properties{sycl_exp::data_placement_striped};

if constexpr (addr_space == access::address_space::local_space) {
// Copy input to local memory.
for (int i = lid * elems_per_wi; i < lid * elems_per_wi + elems_per_wi;
i++) {
local_acc[i] = input[offset + i];
}
ndi.barrier(access::fence_space::local_space);
}

// default
sycl_exp::group_load(g, input.begin() + offset, span{data});
if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_load(g, local_acc.begin(), span{data});
} else {
sycl_exp::group_load(g, input.begin() + offset, span{data});
}
for (int i = 0; i < elems_per_wi; ++i)
load_blocked_default[gid * elems_per_wi + i] = data[i];

// blocked
sycl_exp::group_load(g, input.begin() + offset, span{data}, blocked);
if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_load(g, local_acc.begin(), span{data}, blocked);
} else {
sycl_exp::group_load(g, input.begin() + offset, span{data}, blocked);
}
for (int i = 0; i < elems_per_wi; ++i)
load_blocked[gid * elems_per_wi + i] = data[i];

// striped
sycl_exp::group_load(g, input.begin() + offset, span{data}, striped);
if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_load(g, local_acc.begin(), span{data}, striped);
} else {
sycl_exp::group_load(g, input.begin() + offset, span{data}, striped);
}
for (int i = 0; i < elems_per_wi; ++i)
load_striped[gid * elems_per_wi + i] = data[i];

// Stores:

std::iota(std::begin(data), std::end(data), gid * elems_per_wi);

sycl_exp::group_store(g, span{data},
store_blocked_default.begin() + offset);
sycl_exp::group_store(g, span{data}, store_blocked.begin() + offset,
blocked);
sycl_exp::group_store(g, span{data}, store_striped.begin() + offset,
striped);
auto copy_local_acc_to_global_output = [&](accessor<int, 1> output) {
for (int i = lid * elems_per_wi; i < lid * elems_per_wi + elems_per_wi;
i++) {
output[offset + i] = local_acc[i];
}
};

if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_store(g, span{data}, local_acc.begin());
copy_local_acc_to_global_output(store_blocked_default);
} else {
sycl_exp::group_store(g, span{data},
store_blocked_default.begin() + offset);
}

if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_store(g, span{data}, local_acc.begin(), blocked);
copy_local_acc_to_global_output(store_blocked);
} else {
sycl_exp::group_store(g, span{data}, store_blocked.begin() + offset,
blocked);
}

if constexpr (addr_space == access::address_space::local_space) {
sycl_exp::group_store(g, span{data}, local_acc.begin(), striped);
copy_local_acc_to_global_output(store_striped);
} else {
sycl_exp::group_store(g, span{data}, store_striped.begin() + offset,
striped);
}
});
});

Expand Down Expand Up @@ -111,3 +157,10 @@ int main() {

return 0;
}

int main() {
queue q;
test<access::address_space::global_space>(q);
test<access::address_space::local_space>(q);
return 0;
}
Loading

0 comments on commit 5edcf74

Please sign in to comment.