Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement groupreduce API #559

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
@uniform
@groupsize
@ndrange
synchronize
allocate
@groupreduce
```

## Host language

```@docs
synchronize
allocate
KernelAbstractions.zeros
```

Expand Down
2 changes: 2 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,8 @@ function __fake_compiler_job end
# - LoopInfo
###

include("reduce.jl")

include("extras/extras.jl")

include("reflection.jl")
Expand Down
140 changes: 140 additions & 0 deletions src/reduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
export @groupreduce

# Singleton `Val` tags used to select a reduction strategy via dispatch:
# `Reduction.thread` -> shared-memory tree reduction across work-items,
# `Reduction.warp`   -> warp-shuffle based reduction (requires backend support).
module Reduction
const thread = Val(:thread)
const warp = Val(:warp)
end

"""
@groupreduce op val neutral [groupsize]

Perform group reduction of `val` using `op`.
If backend supports warp reduction, it will use it instead of thread reduction.

# Arguments

- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.

- `groupsize` specifies size of the workgroup.
If a kernel does not specifies `groupsize` statically, then it is required to
provide `groupsize`.
Also can be used to perform reduction accross first `groupsize` threads
(if `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
macro groupreduce(op, val, neutral)
return quote
if __supports_warp_reduction()
__groupreduce(
$(esc(:__ctx__)),
$(esc(op)),
$(esc(val)),
$(esc(neutral)),
Val(prod($groupsize($(esc(:__ctx__))))),
$(esc(Reduction.warp)),
)
else
__groupreduce(
$(esc(:__ctx__)),
$(esc(op)),
$(esc(val)),
$(esc(neutral)),
Val(prod($groupsize($(esc(:__ctx__))))),
$(esc(Reduction.thread)),
)
end
end
end

# 4-argument form: caller supplies `groupsize` explicitly (required when the kernel
# does not declare a static groupsize, or to reduce over only the first
# `groupsize` work-items).
macro groupreduce(op, val, neutral, groupsize)
    return quote
        if __supports_warp_reduction()
            __groupreduce(
                $(esc(:__ctx__)),
                $(esc(op)),
                $(esc(val)),
                $(esc(neutral)),
                Val($(esc(groupsize))),
                $(esc(Reduction.warp)),
            )
        else
            __groupreduce(
                $(esc(:__ctx__)),
                $(esc(op)),
                $(esc(val)),
                $(esc(neutral)),
                Val($(esc(groupsize))),
                $(esc(Reduction.thread)),
            )
        end
    end
end

# Thread (shared-memory tree) group reduction.
#
# Stages `val` from each work-item into local memory, then halves the active
# range each step, combining pairs with `op`, until element 1 holds the result.
#
# NOTE(review): assumes `groupsize` is a power of two — for non-power-of-two
# sizes the tree halving (`s = groupsize ÷ 2`, then `s >>= 1`) drops partial
# elements; confirm callers only use power-of-two group sizes.
#
# NOTE(review): per PR discussion, `@synchronize` below a conditional inside a
# kernel is currently not legal on the CPU backend (#262 / #556) — this code is
# assumed to be GPU-only.
#
# Returns the reduction result; only valid on the work-item with local index 1
# (other work-items return their original `val`).
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize}
    storage = @localmem T groupsize

    local_idx = @index(Local)
    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
    @synchronize()

    s::UInt64 = groupsize ÷ 0x02
    while s > 0x00
        # Lower half combines with the element `s` slots above it.
        if (local_idx - 0x01) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        @synchronize()
        s >>= 0x01
    end

    if local_idx == 0x01
        @inbounds val = storage[local_idx]
    end
    return val
end

# Warp groupreduce.

# NOTE: Backends should implement these two device functions (with `@device_override`).
# `__shfl_down(val, offset)` has no generic fallback: it must return the `val`
# held by the lane `offset` lanes below the current one within the warp.
function __shfl_down end
# Generic fallback: a backend that does not override this advertises no
# warp-shuffle support, so `@groupreduce` falls back to thread reduction.
__supports_warp_reduction() = false

# Assume warp is 32 lanes.
# NOTE(review): this is hardware-dependent (e.g. AMD wavefronts can be 64 lanes)
# — confirm per backend.
const __warpsize = UInt32(32)
# Maximum number of warps (for a groupsize = 1024, i.e. 1024 ÷ 32 bins).
const __warp_bins = UInt32(32)

# Reduce `val` across the lanes of one warp by repeatedly combining with the
# value shuffled down by a halving stride; after the last step the first lane
# holds the reduction of all lanes.
@inline function __warp_reduce(val, op)
    stride::UInt32 = __warpsize >> 0x01
    while stride != 0x00
        shuffled = __shfl_down(val, stride)
        val = op(val, shuffled)
        stride >>= 0x01
    end
    return val
end

# Warp group reduction: each warp reduces its lanes with shuffles, warp leaders
# stage partial results in local memory, and the first warp reduces those bins.
# Returns the reduction result; only valid on the first lane of the first warp.
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize}
    # One local-memory bin per warp (sized for the maximum warp count).
    storage = @localmem T __warp_bins

    local_idx = @index(Local)
    # 1-based lane index within the warp and 1-based warp id within the group.
    lane = (local_idx - 0x01) % __warpsize + 0x01
    warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01

    # Each warp performs a reduction and writes results into its own bin in `storage`.
    val = __warp_reduce(val, op)
    @inbounds lane == 0x01 && (storage[warp_id] = val)
    @synchronize()

    # Final reduction of the `storage` on the first warp.
    # NOTE(review): `groupsize ÷ __warpsize` assumes `groupsize` is a multiple of
    # the warp size — a partial trailing warp's bin would be skipped here;
    # confirm callers guarantee this.
    within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize
    @inbounds val = within_storage ? storage[lane] : neutral
    warp_id == 0x01 && (val = __warp_reduce(val, op))
    return val
end
35 changes: 35 additions & 0 deletions test/groupreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Test kernel: reduce `x` over the group using the kernel's static groupsize.
# Out-of-range work-items contribute the neutral element; only the first
# work-item writes the result to `y[1]`.
# (Review-page junk that was spliced into this kernel has been removed, and the
# reviewer-suggested `cpu = false` spacing applied.)
@kernel cpu = false function groupreduce_1!(y, x, op, neutral)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral)
    i == 1 && (y[1] = res)
end

# Test kernel: reduce `x` over the first `groupsize` work-items of the group
# (explicit groupsize form of `@groupreduce`); the first work-item stores the
# result in `y[1]`.
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    idx = @index(Global)
    item = idx ≤ length(x) ? x[idx] : neutral
    red = @groupreduce(op, item, neutral, groupsize)
    idx == 1 && (y[1] = red)
end

# Run the `@groupreduce` testsuite against `backend` with array type `AT`.
# Checks a ones-vector sum for several element types and group sizes, both with
# the static-groupsize form (`groupreduce_1!`) and the explicit-groupsize form
# (`groupreduce_2!`).
# (Review-page junk that was spliced into this function has been removed, and
# the reviewer-suggested explicit `return` of the testset applied.)
function groupreduce_testsuite(backend, AT)
    # TODO should be a better way of querying max groupsize
    groupsizes = "$backend" == "oneAPIBackend" ?
        (256,) :
        (256, 512, 1024)
    return @testset "@groupreduce" begin
        @testset "T=$T, n=$n" for T in (Float16, Float32, Float64, Int16, Int32, Int64), n in groupsizes
            x = AT(ones(T, n))
            y = AT(zeros(T, 1))

            # Full-group reduction: sum of `n` ones is `n`.
            groupreduce_1!(backend(), n)(y, x, +, zero(T); ndrange = n)
            @test Array(y)[1] == n

            # Partial reductions over the first 128 / 64 work-items.
            groupreduce_2!(backend())(y, x, +, zero(T), Val(128); ndrange = n)
            @test Array(y)[1] == 128

            groupreduce_2!(backend())(y, x, +, zero(T), Val(64); ndrange = n)
            @test Array(y)[1] == 64
        end
    end
end
8 changes: 8 additions & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ include("reflection.jl")
include("examples.jl")
include("convert.jl")
include("specialfunctions.jl")
include("groupreduce.jl")

function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
@conditional_testset "Unittests" skip_tests begin
Expand Down Expand Up @@ -92,6 +93,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
examples_testsuite(backend_str)
end

# TODO @index(Local) only works as a top-level expression on CPU.
if backend != CPU
@conditional_testset "@groupreduce" skip_tests begin
groupreduce_testsuite(backend, AT)
end
end

return
end

Expand Down
Loading