Implement groupreduce API #559

Draft: wants to merge 11 commits into base: main (showing changes from 6 commits)
5 changes: 3 additions & 2 deletions docs/src/api.md
@@ -13,13 +13,14 @@
@uniform
@groupsize
@ndrange
synchronize
allocate
@groupreduce
```

## Host language

```@docs
synchronize
allocate
KernelAbstractions.zeros
```

2 changes: 2 additions & 0 deletions src/KernelAbstractions.jl
@@ -798,6 +798,8 @@ function __fake_compiler_job end
# - LoopInfo
###

include("reduce.jl")

include("extras/extras.jl")

include("reflection.jl")
130 changes: 130 additions & 0 deletions src/reduce.jl
@@ -0,0 +1,130 @@
export @groupreduce, Reduction

module Reduction
const thread = Val(:thread)
const warp = Val(:warp)
end

"""
@groupreduce op val neutral algo [groupsize]

Perform group reduction of `val` using `op`.

# Arguments

- `algo` specifies which reduction algorithm to use:
- `Reduction.thread`:
Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
Available across all backends.
- `Reduction.warp`:
Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
Potentially faster, since it requires fewer writes to shared memory.
To query if backend supports warp reduction, use `supports_warp_reduction(backend)`.
Review comment (Member): Why is that needed? Shouldn't the backend go and use warp reductions if it can?

Reply (Member Author): I'm now doing an auto-selection of the algorithm based on device function `__supports_warp_reduction()`.


- `neutral` should be a neutral element w.r.t. `op`, such that `op(neutral, x) == x`.

- `groupsize` specifies the size of the workgroup.
    If a kernel does not specify `groupsize` statically, then it must be provided explicitly.
    It can also be used to perform the reduction across only the first `groupsize` threads
    (when `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
macro groupreduce(op, val, neutral, algo)
return quote
__groupreduce(
$(esc(:__ctx__)),
$(esc(op)),
$(esc(val)),
$(esc(neutral)),
Val(prod($groupsize($(esc(:__ctx__))))),
$(esc(algo)),
)
end
end

macro groupreduce(op, val, neutral, algo, groupsize)
return quote
__groupreduce(
$(esc(:__ctx__)),
$(esc(op)),
$(esc(val)),
$(esc(neutral)),
Val($(esc(groupsize))),
$(esc(algo)),
)
end
end
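For orientation, a minimal usage sketch of the macro (hypothetical kernel, not part of this diff; the reviews below note that kernels using `@groupreduce` are GPU-only for now, since it relies on non-top-level `@synchronize`):

```julia
# Hypothetical usage sketch: one partial sum per workgroup.
# Assumes a statically known workgroup size, so the 4-argument form applies.
@kernel function partial_sum!(out, x)
    i = @index(Global)
    g = @index(Group)
    val = i ≤ length(x) ? x[i] : zero(eltype(x))
    res = @groupreduce(+, val, zero(eltype(x)), Reduction.thread)
    @index(Local) == 1 && (out[g] = res)
end

# Launch sketch: pad ndrange to a multiple of the workgroup size so every
# thread in a group participates in the reduction.
# partial_sum!(backend, 256)(out, x; ndrange = cld(length(x), 256) * 256)
```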

function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize}
storage = @localmem T groupsize

local_idx = @index(Local)
@inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
@synchronize()

s::UInt64 = groupsize ÷ 0x02
while s > 0x00
if (local_idx - 0x01) < s
other_idx = local_idx + s
if other_idx ≤ groupsize
@inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
end
end
@synchronize()
Review comment on lines +45 to +52 (Member): Currently this is not legal.

#262 might need to wait until #556

Reply (Member): (I assume this code is GPU only anyways)
s >>= 0x01
end

if local_idx == 0x01
@inbounds val = storage[local_idx]
end
return val
end
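To make the halving loop concrete, here is a host-side emulation (illustration only, not part of the PR). Like the kernel above, it assumes a power-of-two group size; `storage` stands in for the shared-memory buffer and the inner loop plays the role of the threads active on each pass:

```julia
# Sequential emulation of the tree reduction above (illustration only).
function emulate_thread_reduce(op, vals::Vector)
    storage = copy(vals)      # stand-in for the @localmem buffer
    n = length(storage)
    s = n ÷ 2
    while s > 0
        for i in 1:s          # "threads" active on this pass
            j = i + s
            j ≤ n && (storage[i] = op(storage[i], storage[j]))
        end
        s >>= 1               # halve the active set, as in the kernel
    end
    return storage[1]
end

emulate_thread_reduce(+, collect(1:8))  # 36
```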

# Warp groupreduce.

macro shfl_down(val, offset)
return quote
$__shfl_down($(esc(val)), $(esc(offset)))
end
end
Review comment (Member): If it isn't user-facing or needs special CPU handling you don't need to introduce a new macro

Reply (Member Author): Yeah, removed macro.


# Backends should implement these two.
function __shfl_down end
supports_warp_reduction(::Backend) = false
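A hedged sketch of what that backend wiring might look like (the `CUDA.shfl_down_sync` and `CUDA.FULL_MASK` names are assumptions about CUDA.jl, and the exact extension point may differ from what a real backend would use):

```julia
# Hypothetical backend wiring (sketch only, not part of this diff).
import KernelAbstractions as KA
using CUDA  # assumed backend package

# Map the generic shuffle onto the native warp intrinsic (assumed API).
KA.__shfl_down(val, offset) = CUDA.shfl_down_sync(CUDA.FULL_MASK, val, offset)

# Opt in to the warp reduction algorithm.
KA.supports_warp_reduction(::CUDABackend) = true
```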

# Assume warp is 32 lanes.
const __warpsize = UInt32(32)
# Maximum number of warps (for a groupsize = 1024).
const __warp_bins = UInt32(32)

@inline function __warp_reduce(val, op)
offset::UInt32 = __warpsize ÷ 0x02
while offset > 0x00
val = op(val, @shfl_down(val, offset))
offset >>= 0x01
end
return val
end
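For intuition, a host-side emulation of this loop (illustration only): `lanes[i]` stands for `val` on lane `i`, and an out-of-range shuffle source returns the lane's own value, mirroring typical `shfl_down` semantics, so only lane 1's result is meaningful:

```julia
# Sequential emulation of __warp_reduce (illustration only).
function emulate_warp_reduce(op, lanes::Vector)
    vals = copy(lanes)
    n = length(vals)          # stand-in for __warpsize
    offset = n ÷ 2
    while offset > 0
        vals = [op(vals[i], i + offset ≤ n ? vals[i + offset] : vals[i])
                for i in 1:n]
        offset >>= 1
    end
    return vals[1]            # only lane 1 holds the full reduction
end

emulate_warp_reduce(+, collect(1:8))  # 36
```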

function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize}
storage = @localmem T __warp_bins

local_idx = @index(Local)
lane = (local_idx - 0x01) % __warpsize + 0x01
warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01

# Each warp performs a reduction and writes results into its own bin in `storage`.
val = __warp_reduce(val, op)
@inbounds lane == 0x01 && (storage[warp_id] = val)
@synchronize()

# Final reduction of the `storage` on the first warp.
within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize
@inbounds val = within_storage ? storage[lane] : neutral
warp_id == 0x01 && (val = __warp_reduce(val, op))
return val
end
48 changes: 48 additions & 0 deletions test/groupreduce.jl
@@ -0,0 +1,48 @@
@kernel function groupreduce_1!(y, x, op, neutral, algo)
i = @index(Global)
val = i > length(x) ? neutral : x[i]
res = @groupreduce(op, val, neutral, algo)
i == 1 && (y[1] = res)
end

@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
i = @index(Global)
val = i > length(x) ? neutral : x[i]
res = @groupreduce(op, val, neutral, algo, groupsize)
i == 1 && (y[1] = res)
end
Review comment (Member): These need to be cpu=false since you are using non-top-level @synchronize

Reply (Member Author): Done.


function groupreduce_testsuite(backend, AT)
@testset "@groupreduce" begin
Review suggestion (Contributor):

Suggested change:
-    @testset "@groupreduce" begin
+    return @testset "@groupreduce" begin
@testset "thread reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
x = AT(ones(T, n))
y = AT(zeros(T, 1))

groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.thread; ndrange = n)
@test Array(y)[1] == n

groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(128); ndrange = n)
@test Array(y)[1] == 128

groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(64); ndrange = n)
@test Array(y)[1] == 64
end

warp_reduction = KernelAbstractions.supports_warp_reduction(backend())
if warp_reduction
@testset "warp reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)

x = AT(ones(T, n))
y = AT(zeros(T, 1))
groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.warp; ndrange = n)
@test Array(y)[1] == n

groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(128); ndrange = n)
@test Array(y)[1] == 128

groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(64); ndrange = n)
@test Array(y)[1] == 64
end
end
end
end
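For reference, a backend test runner would invoke this suite roughly as follows (the backend and array types are illustrative, not part of this diff):

```julia
# Hypothetical invocation from a GPU backend's test suite.
using CUDA
groupreduce_testsuite(CUDABackend, CuArray)
```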
8 changes: 8 additions & 0 deletions test/testsuite.jl
@@ -38,6 +38,7 @@ include("reflection.jl")
include("examples.jl")
include("convert.jl")
include("specialfunctions.jl")
include("groupreduce.jl")

function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
@conditional_testset "Unittests" skip_tests begin
@@ -92,6 +93,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
examples_testsuite(backend_str)
end

# TODO @index(Local) only works as a top-level expression on CPU.
if backend != CPU
@conditional_testset "@groupreduce" skip_tests begin
groupreduce_testsuite(backend, AT)
end
end

return
end
