Implement groupreduce API #559

Draft: pxl-th wants to merge 11 commits into main

Conversation

pxl-th (Member) commented Jan 30, 2025

Implement a group-level reduction API. It supports two types of algorithms:

  • thread: reduction performed by threads; uses shared memory of length groupsize, with no bank conflicts and no divergence.
  • warp: reduction performed via shfl_down within warps; uses shared memory of length 32. Each warp reduces its values and stores the result in shared memory, followed by a final warp reduction over the values stored there. Backends are only required to implement the shfl_down intrinsic, which AMDGPU/CUDA/Metal have (not sure about other backends).
  • A query function, KA.__supports_warp_reduction(), checks whether a backend supports warp reduction.

res = @groupreduce op val neutral

  • Optionally limit the number of threads that participate in the reduction (see the usage sketch below):

res = @groupreduce op val neutral 128 # first 128 threads will perform the reduction
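
For concreteness, a minimal end-to-end usage sketch modeled on the test kernels in this PR (the kernel name groupsum! and the launch parameters are hypothetical, and @groupreduce is only available with this PR applied):

using KernelAbstractions

@kernel cpu=false function groupsum!(y, x, op, neutral)
    i = @index(Global)
    # Out-of-bounds work-items contribute the neutral element.
    val = i > length(x) ? neutral : x[i]
    # Reduce `val` across the whole workgroup.
    res = @groupreduce(op, val, neutral)
    # The first work-item stores the group's result.
    i == 1 && (y[1] = res)
end

# Launching on a backend with one workgroup of 256 work-items (x has length 256 here):
# groupsum!(backend, 256)(y, x, +, zero(eltype(x)); ndrange = 256)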

github-actions bot (Contributor) commented Jan 30, 2025

Benchmark Results

Benchmark    main    618c840... (this PR)    ratio (main / 618c840...)
saxpy/default/Float16/1024 0.759 ± 0.0071 μs 0.733 ± 0.0068 μs 1.04
saxpy/default/Float16/1048576 0.173 ± 0.0091 ms 0.175 ± 0.0089 ms 0.988
saxpy/default/Float16/16384 3.38 ± 0.032 μs 3.37 ± 0.056 μs 1
saxpy/default/Float16/2048 0.938 ± 0.012 μs 0.905 ± 0.011 μs 1.04
saxpy/default/Float16/256 0.612 ± 0.0058 μs 0.589 ± 0.0057 μs 1.04
saxpy/default/Float16/262144 0.0444 ± 0.00051 ms 0.0443 ± 0.00049 ms 1
saxpy/default/Float16/32768 6.06 ± 0.048 μs 6.06 ± 0.11 μs 1
saxpy/default/Float16/4096 1.35 ± 0.025 μs 1.31 ± 0.022 μs 1.03
saxpy/default/Float16/512 0.677 ± 0.0066 μs 0.648 ± 0.0069 μs 1.04
saxpy/default/Float16/64 0.582 ± 0.0057 μs 0.559 ± 0.0052 μs 1.04
saxpy/default/Float16/65536 11.8 ± 0.13 μs 11.8 ± 0.2 μs 1
saxpy/default/Float32/1024 0.646 ± 0.0094 μs 0.635 ± 0.0099 μs 1.02
saxpy/default/Float32/1048576 0.233 ± 0.015 ms 0.233 ± 0.02 ms 1
saxpy/default/Float32/16384 2.86 ± 0.29 μs 2.78 ± 0.14 μs 1.03
saxpy/default/Float32/2048 0.757 ± 0.026 μs 0.769 ± 0.048 μs 0.984
saxpy/default/Float32/256 0.584 ± 0.008 μs 0.573 ± 0.0071 μs 1.02
saxpy/default/Float32/262144 0.0578 ± 0.0034 ms 0.0564 ± 0.0035 ms 1.03
saxpy/default/Float32/32768 5.43 ± 0.56 μs 5.35 ± 0.31 μs 1.01
saxpy/default/Float32/4096 1.12 ± 0.051 μs 1.13 ± 0.083 μs 0.987
saxpy/default/Float32/512 0.611 ± 0.0086 μs 0.604 ± 0.0065 μs 1.01
saxpy/default/Float32/64 0.565 ± 0.0052 μs 0.557 ± 0.0056 μs 1.01
saxpy/default/Float32/65536 12.6 ± 0.91 μs 12.4 ± 0.85 μs 1.01
saxpy/default/Float64/1024 0.752 ± 0.073 μs 0.765 ± 0.059 μs 0.983
saxpy/default/Float64/1048576 0.486 ± 0.025 ms 0.501 ± 0.028 ms 0.971
saxpy/default/Float64/16384 5.4 ± 0.46 μs 5.37 ± 0.36 μs 1.01
saxpy/default/Float64/2048 1.17 ± 0.12 μs 1.14 ± 0.088 μs 1.02
saxpy/default/Float64/256 0.574 ± 0.0076 μs 0.593 ± 0.0079 μs 0.968
saxpy/default/Float64/262144 0.114 ± 0.0083 ms 0.115 ± 0.01 ms 0.997
saxpy/default/Float64/32768 12.5 ± 0.74 μs 12.7 ± 0.8 μs 0.991
saxpy/default/Float64/4096 1.72 ± 0.16 μs 1.69 ± 0.11 μs 1.02
saxpy/default/Float64/512 0.625 ± 0.014 μs 0.647 ± 0.012 μs 0.967
saxpy/default/Float64/64 0.549 ± 0.0066 μs 0.564 ± 0.0059 μs 0.974
saxpy/default/Float64/65536 28.7 ± 1.4 μs 28.6 ± 1.9 μs 1
saxpy/static workgroup=(1024,)/Float16/1024 2.2 ± 0.026 μs 2.17 ± 0.026 μs 1.01
saxpy/static workgroup=(1024,)/Float16/1048576 0.159 ± 0.0088 ms 0.16 ± 0.008 ms 0.994
saxpy/static workgroup=(1024,)/Float16/16384 4.47 ± 0.13 μs 4.49 ± 0.14 μs 0.996
saxpy/static workgroup=(1024,)/Float16/2048 2.38 ± 0.029 μs 2.36 ± 0.026 μs 1.01
saxpy/static workgroup=(1024,)/Float16/256 2.83 ± 0.034 μs 2.8 ± 0.038 μs 1.01
saxpy/static workgroup=(1024,)/Float16/262144 0.0425 ± 0.0015 ms 0.0426 ± 0.0014 ms 0.997
saxpy/static workgroup=(1024,)/Float16/32768 6.86 ± 0.21 μs 6.93 ± 0.27 μs 0.99
saxpy/static workgroup=(1024,)/Float16/4096 2.71 ± 0.037 μs 2.68 ± 0.037 μs 1.01
saxpy/static workgroup=(1024,)/Float16/512 3.28 ± 0.036 μs 3.24 ± 0.034 μs 1.01
saxpy/static workgroup=(1024,)/Float16/64 2.53 ± 0.21 μs 2.5 ± 0.21 μs 1.01
saxpy/static workgroup=(1024,)/Float16/65536 12.6 ± 0.32 μs 12.7 ± 0.39 μs 0.988
saxpy/static workgroup=(1024,)/Float32/1024 2.24 ± 0.032 μs 2.24 ± 0.033 μs 1
saxpy/static workgroup=(1024,)/Float32/1048576 0.243 ± 0.016 ms 0.237 ± 0.021 ms 1.03
saxpy/static workgroup=(1024,)/Float32/16384 4.43 ± 0.29 μs 4.65 ± 0.69 μs 0.953
saxpy/static workgroup=(1024,)/Float32/2048 2.39 ± 0.053 μs 2.38 ± 0.033 μs 1.01
saxpy/static workgroup=(1024,)/Float32/256 2.69 ± 0.058 μs 2.69 ± 0.051 μs 1
saxpy/static workgroup=(1024,)/Float32/262144 0.0604 ± 0.0037 ms 0.0599 ± 0.004 ms 1.01
saxpy/static workgroup=(1024,)/Float32/32768 7.47 ± 0.43 μs 7.66 ± 1.1 μs 0.975
saxpy/static workgroup=(1024,)/Float32/4096 2.68 ± 0.084 μs 2.65 ± 0.058 μs 1.01
saxpy/static workgroup=(1024,)/Float32/512 2.7 ± 0.036 μs 2.69 ± 0.031 μs 1
saxpy/static workgroup=(1024,)/Float32/64 2.72 ± 5.5 μs 2.71 ± 5.1 μs 1.01
saxpy/static workgroup=(1024,)/Float32/65536 15.5 ± 1 μs 15.8 ± 1.2 μs 0.977
saxpy/static workgroup=(1024,)/Float64/1024 2.34 ± 0.07 μs 2.31 ± 0.047 μs 1.01
saxpy/static workgroup=(1024,)/Float64/1048576 0.521 ± 0.027 ms 0.5 ± 0.035 ms 1.04
saxpy/static workgroup=(1024,)/Float64/16384 7.28 ± 0.38 μs 7.3 ± 0.78 μs 0.998
saxpy/static workgroup=(1024,)/Float64/2048 2.63 ± 0.083 μs 2.59 ± 0.06 μs 1.01
saxpy/static workgroup=(1024,)/Float64/256 2.65 ± 0.052 μs 2.64 ± 0.056 μs 1
saxpy/static workgroup=(1024,)/Float64/262144 0.118 ± 0.0073 ms 0.118 ± 0.012 ms 1
saxpy/static workgroup=(1024,)/Float64/32768 15.3 ± 1.5 μs 15.6 ± 1.4 μs 0.986
saxpy/static workgroup=(1024,)/Float64/4096 3.19 ± 0.22 μs 3.28 ± 0.23 μs 0.972
saxpy/static workgroup=(1024,)/Float64/512 2.67 ± 0.057 μs 2.67 ± 0.062 μs 1
saxpy/static workgroup=(1024,)/Float64/64 2.62 ± 0.061 μs 2.61 ± 0.062 μs 1.01
saxpy/static workgroup=(1024,)/Float64/65536 31.2 ± 2.4 μs 31.2 ± 2 μs 1
time_to_load 0.319 ± 0.0037 s 0.32 ± 0.0028 s 0.998

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

pxl-th (Member, Author) commented Jan 30, 2025

@vchuravy not sure about CPU errors (regarding @index(Local)). Any idea?

Update: #218 (comment)

end

function groupreduce_testsuite(backend, AT)
@testset "@groupreduce" begin
Contributor:

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

groupsizes = "$backend" == "oneAPIBackend" ?
(256,) :
(256, 512, 1024)
@testset "@groupreduce" begin
Contributor:

Suggested change
@testset "@groupreduce" begin
return @testset "@groupreduce" begin

src/reduce.jl Outdated
Comment on lines 15 to 22
- `algo` specifies which reduction algorithm to use:
  - `Reduction.thread`:
    Perform a thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
    Available across all backends.
  - `Reduction.warp`:
    Perform a warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
    Potentially faster, since it requires fewer writes to shared memory.
    To query whether a backend supports warp reduction, use `supports_warp_reduction(backend)`.
Member:

Why is that needed? Shouldn't the backend go and use warp reductions if it can?

Member Author:

I'm now doing an auto-selection of the algorithm based on device function __supports_warp_reduction().
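
A hedged sketch of what that auto-selection might look like (the helper name __reduce_algorithm is hypothetical; Reduction.thread, Reduction.warp, and __supports_warp_reduction are the names used elsewhere in this PR, and the exact signature may differ):

# Hypothetical helper: pick the warp algorithm when the backend supports it,
# otherwise fall back to the portable thread algorithm.
@inline function __reduce_algorithm()
    return __supports_warp_reduction() ? Reduction.warp : Reduction.thread
end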

Comment on lines +70 to +77
while s > 0x00
    if (local_idx - 0x01) < s
        other_idx = local_idx + s
        if other_idx ≤ groupsize
            @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
        end
    end
    @synchronize()
Member:

Currently this is not legal.

#262 might need to wait until #556

Member:

(I assume this code is GPU only anyways)
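
For context, a hedged completion of the excerpt above showing the full shared-memory tree reduction; the initialization and halving of s are not part of the quoted lines and are assumptions here:

s = nextpow(2, groupsize) >> 0x01   # assumed: start with the largest power-of-two stride
while s > 0x00
    if (local_idx - 0x01) < s
        other_idx = local_idx + s
        if other_idx ≤ groupsize
            # Each active work-item folds its partner's value into its own slot.
            @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
        end
    end
    # @synchronize inside the loop is what the review comment above flags as not (yet) legal in KA.
    @synchronize()
    s >>= 0x01                      # assumed: halve the stride each iteration
end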

src/reduce.jl Outdated
Comment on lines 89 to 93
macro shfl_down(val, offset)
    return quote
        $__shfl_down($(esc(val)), $(esc(offset)))
    end
end
Member:

If it isn't user-facing and doesn't need special CPU handling, you don't need to introduce a new macro.

Member Author:

Yeah, removed macro.
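
For reference, a generic warp-level reduction loop built directly on the __shfl_down intrinsic that the PR description says backends must provide; this is a sketch assuming a warp size of 32, not the exact code from this PR:

# Each lane starts with its own `val`; after the loop the first lane of the
# warp holds the reduction of all 32 lanes.
offset = 0x00000010                      # 16 = warp size ÷ 2
while offset > 0x00
    val = op(val, __shfl_down(val, offset))
    offset >>= 0x01
end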

Comment on lines 1 to 13
@kernel function groupreduce_1!(y, x, op, neutral, algo)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo)
    i == 1 && (y[1] = res)
end

@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo, groupsize)
    i == 1 && (y[1] = res)
end
Member:

These need to be cpu=false since you are using non-top-level @synchronize

Member Author:

Done.

vchuravy (Member) left a comment

Thanks! This currently doesn't fully work since calling a function with a __ctx__ argument is GPU only.

#558 is also rearing its ugly head. I suspect we will need a macro-free kernel language in KA for writing this correctly.

Comment on lines +1 to +2
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
Contributor:

Suggested change
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
i = @index(Global)
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
@kernel cpu = false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}

pxl-th requested a review from vchuravy February 3, 2025 22:56
pxl-th marked this pull request as draft February 5, 2025 22:48
vchuravy mentioned this pull request Feb 6, 2025