Skip to content

Commit

Permalink
Update for ROCm 6.1.0
Browse files Browse the repository at this point in the history
A few functions used in Scan have been deprecated in 6.1.0.
  • Loading branch information
WeiqunZhang committed Apr 17, 2024
1 parent 9799e16 commit 294ab2b
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions Src/Base/AMReX_Scan.H
Original file line number Diff line number Diff line change
Expand Up @@ -641,13 +641,33 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
using ScanTileState = rocprim::detail::lookback_scan_state<T>;
using OrderedBlockId = rocprim::detail::ordered_block_id<unsigned int>;

#if (defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR < 6)) || \
(defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR == 6) && \
defined(HIP_VERSION_MINOR) && (HIP_VERSION_MINOR == 0))

std::size_t nbytes_tile_state = rocprim::detail::align_size
(ScanTileState::get_storage_size(nblocks));
std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();

auto dp = (char*)(The_Arena()->alloc(nbytes_tile_state+nbytes_block_id));

ScanTileState tile_state = ScanTileState::create(dp, nblocks);

#else

std::size_t nbytes_tile_state;
AMREX_HIP_SAFE_CALL(ScanTileState::get_storage_size(nblocks, stream, nbytes_tile_state));
nbytes_tile_state = rocprim::detail::align_size(nbytes_tile_state);

std::size_t nbytes_block_id = OrderedBlockId::get_storage_size();

auto dp = (char*)(The_Arena()->alloc(nbytes_tile_state+nbytes_block_id));

ScanTileState tile_state;
AMREX_HIP_SAFE_CALL(ScanTileState::create(tile_state, dp, nblocks, stream));

#endif

auto ordered_block_id = OrderedBlockId::create
(reinterpret_cast<OrderedBlockId::id_type*>(dp + nbytes_tile_state));

Expand Down

0 comments on commit 294ab2b

Please sign in to comment.