Implement groupreduce API #559
Conversation
Benchmark Results
Benchmark Plots: A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
UPD: #218 (comment)
end

function groupreduce_testsuite(backend, AT)
    @testset "@groupreduce" begin

Suggested change:
    return @testset "@groupreduce" begin
groupsizes = "$backend" == "oneAPIBackend" ?
    (256,) :
    (256, 512, 1024)
@testset "@groupreduce" begin

Suggested change:
    return @testset "@groupreduce" begin
src/reduce.jl
Outdated
- `algo` specifies which reduction algorithm to use:
    - `Reduction.thread`:
      Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
      Available across all backends.
    - `Reduction.warp`:
      Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
      Potentially faster, since it requires fewer writes to shared memory.
      To query if a backend supports warp reduction, use `supports_warp_reduction(backend)`.
Why is that needed? Shouldn't the backend go and use warp reductions if it can?
I'm now doing auto-selection of the algorithm based on the device function __supports_warp_reduction().
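A minimal sketch of what that auto-selection could look like; only the `__supports_warp_reduction()` name comes from the comment above, while the surrounding function and the two branch helpers are illustrative assumptions, not the PR's exact code:

```julia
# Sketch only: choose the reduction path in device code based on a per-backend
# query. `__warp_groupreduce` and `__thread_groupreduce` are hypothetical names.
@inline function __select_groupreduce(op, val, neutral)
    if __supports_warp_reduction()
        return __warp_groupreduce(op, val, neutral)    # shuffle-based path
    else
        return __thread_groupreduce(op, val, neutral)  # shared-memory path
    end
end
```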
while s > 0x00
    if (local_idx - 0x01) < s
        other_idx = local_idx + s
        if other_idx ≤ groupsize
            @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
        end
    end
    @synchronize()
(I assume this code is GPU only anyways)
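For context, here is a self-contained sketch of the shared-memory tree reduction the loop above belongs to. The kernel name, the fixed shared-memory size, and the argument layout are illustrative assumptions, not the PR's exact code:

```julia
using KernelAbstractions

# Sketch only: GPU-only kernel performing a tree reduction in shared memory.
@kernel cpu = false function tree_groupreduce!(y, x, op, neutral)
    i = @index(Global)
    local_idx = @index(Local)
    groupsize = @groupsize()[1]

    # Shared memory; the size must be ≥ the largest workgroup size used.
    storage = @localmem eltype(y) (1024,)

    # Each thread loads one element (or the neutral element past the end).
    @inbounds storage[local_idx] = i > length(x) ? neutral : x[i]
    @synchronize()

    # Tree reduction: halve the number of active threads each step.
    s = groupsize ÷ 0x02
    while s > 0x00
        if (local_idx - 0x01) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        @synchronize()
        s ÷= 0x02
    end

    # The thread with global index 1 writes out its group's result.
    i == 1 && (@inbounds y[1] = storage[1])
end
```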
src/reduce.jl
Outdated
macro shfl_down(val, offset)
    return quote
        $__shfl_down($(esc(val)), $(esc(offset)))
    end
end
If it isn't user-facing or needs special CPU handling, you don't need to introduce a new macro
Yeah, removed macro.
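For reference, a minimal sketch of the warp-level reduction such a `__shfl_down` intrinsic enables; the function name `__warp_reduce` and the fixed warp size of 32 are assumptions for illustration:

```julia
# Sketch only: shuffle-down reduction within one warp, assuming the backend
# provides a `__shfl_down(val, offset)` intrinsic as discussed above.
@inline function __warp_reduce(op, val)
    offset = 0x00000010           # warpsize ÷ 2, assuming a warp size of 32
    while offset > 0x00000000
        # Combine this lane's value with the value held `offset` lanes above it.
        val = op(val, __shfl_down(val, offset))
        offset >>= 0x01
    end
    return val                    # lane 1 ends up with the whole warp's reduction
end
```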
test/groupreduce.jl
Outdated
@kernel function groupreduce_1!(y, x, op, neutral, algo)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo)
    i == 1 && (y[1] = res)
end

@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo, groupsize)
    i == 1 && (y[1] = res)
end
These need to be cpu=false since you are using non-top-level @synchronize
Done.
Thanks! This currently doesn't fully work since calling a function with a __ctx__ argument is GPU-only.
#558 is also rearing its ugly head. I suspect we will need a macro-free kernel language in KA for writing this correctly.
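For reference, a hedged sketch of how such a GPU-only test kernel is typically compiled and launched; the helper name `run_groupreduce_1`, the array sizes, and the expected value are illustrative, while `backend` and `AT` are the arguments of `groupreduce_testsuite`:

```julia
using KernelAbstractions

# Sketch only: launch the `cpu = false` kernel `groupreduce_1!` from a test.
function run_groupreduce_1(backend, AT; groupsize = 256, n = 2048)
    x = AT(ones(Float32, n))
    y = AT(zeros(Float32, 1))
    kernel = groupreduce_1!(backend, groupsize)        # instantiate for this backend/workgroup size
    kernel(y, x, +, zero(Float32); ndrange = groupsize)
    KernelAbstractions.synchronize(backend)
    return Array(y)[1]   # for the all-ones input this should equal `groupsize`
end
```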
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
    i = @index(Global)

Suggested change:
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
Implement reduction API. Supports two types of algorithms:
- Thread group reduction: uses shmem of length `groupsize`, no bank conflict, no divergence.
- `shfl_down` within warps: uses shmem of length `32`; reduction within warps storing results in shmem, followed by a final warp reduction using the values stored in shmem. Backends are only required to implement the `shfl_down` intrinsic, which AMDGPU/CUDA/Metal have (not sure about other backends). To query whether a backend supports warp reduction, use `KA.__supports_warp_reduction()`.
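A usage sketch of the API described above. The `@groupreduce(op, val, neutral)` argument order follows the earlier test kernels with the `algo` argument dropped, which is an assumption; the kernel name and output layout are illustrative:

```julia
using KernelAbstractions   # assumes @groupreduce is available from KA once this PR lands

# Sketch only: sum-reduce each workgroup's slice of `x` into `partial`.
@kernel cpu = false function block_sum!(partial, x, neutral)
    i = @index(Global)
    g = @index(Group)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(+, val, neutral)        # group-wide reduction of `val`
    @index(Local) == 1 && (partial[g] = res)   # one result per workgroup
end
```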