Implement groupreduce API #559
@@ -0,0 +1,101 @@
export @groupreduce, @warp_groupreduce

"""
    @groupreduce op val neutral [groupsize]

Perform a group reduction of `val` using `op`.
If the backend supports warp reduction, it is used instead of thread reduction.

# Arguments

- `neutral` should be a neutral element w.r.t. `op`, such that `op(neutral, x) == x`.

- `groupsize` specifies the size of the workgroup.
  If a kernel does not specify `groupsize` statically, then it must be provided explicitly.
  It can also be used to perform the reduction across only the first `groupsize` threads
  (when `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
macro groupreduce(op, val)
    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro groupreduce(op, val, groupsize)
    :(__thread_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), Val($(esc(groupsize)))))
end

macro warp_groupreduce(op, val, neutral)
    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val(prod($groupsize($(esc(:__ctx__)))))))
end
macro warp_groupreduce(op, val, neutral, groupsize)
    :(__warp_groupreduce($(esc(:__ctx__)), $(esc(op)), $(esc(val)), $(esc(neutral)), Val($(esc(groupsize)))))
end

function __thread_groupreduce(__ctx__, op, val::T, ::Val{groupsize}) where {T, groupsize}
    storage = @localmem T groupsize

    local_idx = @index(Local)
    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
    @synchronize()

    # Tree reduction in local memory: halve the number of active threads each step.
    s::UInt64 = groupsize ÷ 0x02
    while s > 0x00
        if (local_idx - 0x01) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        @synchronize()
        s >>= 0x01
    end

    if local_idx == 0x01
        @inbounds val = storage[local_idx]
    end
    return val
end

# Warp groupreduce.

# NOTE: Backends should implement these two device functions (with `@device_override`).
function __shfl_down end
function __supports_warp_reduction()
    return false
end

# Assume a warp is 32 lanes.
const __warpsize = UInt32(32)
# Maximum number of warps (for a groupsize of 1024).
const __warp_bins = UInt32(32)

@inline function __warp_reduce(val, op)
    offset::UInt32 = __warpsize ÷ 0x02
    while offset > 0x00
        val = op(val, __shfl_down(val, offset))
        offset >>= 0x01
    end
    return val
end

function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
    storage = @localmem T __warp_bins

    local_idx = @index(Local)
    lane = (local_idx - 0x01) % __warpsize + 0x01
    warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01

    # Each warp performs a reduction and writes its result into its own bin in `storage`.
    val = __warp_reduce(val, op)
    @inbounds lane == 0x01 && (storage[warp_id] = val)
    @synchronize()

    # Final reduction of `storage` on the first warp.
    within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize
    @inbounds val = within_storage ? storage[lane] : neutral
    warp_id == 0x01 && (val = __warp_reduce(val, op))
    return val
end
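The two device hooks above (`__shfl_down`, `__supports_warp_reduction`) are the extension point a backend must fill in before `@warp_groupreduce` becomes usable. Below is a minimal sketch of that glue for a CUDA-like backend; it assumes the backend exposes a warp shuffle intrinsic (`shfl_down_sync`) and a method-overlay macro such as `@device_override`, and the names, import paths, and signatures are illustrative rather than actual backend code:

```julia
# Hypothetical backend glue (illustrative, not part of this PR).
# `shfl_down_sync` and `@device_override` are assumed to be provided by the
# backend's device package; the real names and paths may differ.
using CUDA: shfl_down_sync, @device_override
import KernelAbstractions as KA

const FULL_MASK = 0xffffffff  # all 32 lanes of the warp participate

@device_override function KA.__shfl_down(val, offset)
    # Fetch `val` from the lane `offset` positions below the current lane.
    return shfl_down_sync(FULL_MASK, val, offset)
end

# Advertise warp-reduction support so the warp path can be taken.
@device_override KA.__supports_warp_reduction() = true
```

With overrides like these in place, `__supports_warp_reduction()` returning `true` on the device is what the docstring above relies on when it says the warp path is used instead of thread reduction.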
@@ -0,0 +1,35 @@
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral)
    i == 1 && (y[1] = res)
end

@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, groupsize)
    i == 1 && (y[1] = res)
end

function groupreduce_testsuite(backend, AT)
    # TODO: there should be a better way of querying the maximum groupsize.
    groupsizes = "$backend" == "oneAPIBackend" ?
        (256,) :
        (256, 512, 1024)
    @testset "@groupreduce" begin
        @testset "T=$T, n=$n" for T in (Float16, Float32, Int16, Int32, Int64), n in groupsizes
            x = AT(ones(T, n))
            y = AT(zeros(T, 1))

            groupreduce_1!(backend(), n)(y, x, +, zero(T); ndrange = n)
            @test Array(y)[1] == n

            groupreduce_2!(backend())(y, x, +, zero(T), Val(128); ndrange = n)
            @test Array(y)[1] == 128

            groupreduce_2!(backend())(y, x, +, zero(T), Val(64); ndrange = n)
            @test Array(y)[1] == 64
        end
    end
end
Review comment: Currently this is not legal. #262 might need to wait until #556
Review comment: (I assume this code is GPU only anyways)
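Beyond the testsuite, here is a minimal sketch of how the exported API could be used from user code: a two-pass sum in which every workgroup reduces its chunk with `@groupreduce` and writes one partial result, and the small array of partials is finished on the host. The kernel and driver names (`partial_sum!`, `sum_groupreduce`) are illustrative only, and the call follows the `@groupreduce op val neutral` form used in the docstring and tests:

```julia
using KernelAbstractions

# Illustrative kernel (not part of the PR): each workgroup reduces its slice
# of `x` and stores one partial result. Out-of-range threads contribute `neutral`.
@kernel cpu=false function partial_sum!(partials, x, neutral)
    i = @index(Global)
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(+, val, neutral)
    @index(Local) == 1 && (partials[@index(Group)] = res)
end

# Host-side driver (sketch): one device pass produces per-group partials,
# which are then summed on the host. With backend warp support, the kernel
# could use `@warp_groupreduce` instead.
function sum_groupreduce(x; groupsize = 256)
    backend = get_backend(x)
    ngroups = cld(length(x), groupsize)
    partials = KernelAbstractions.zeros(backend, eltype(x), ngroups)
    partial_sum!(backend, groupsize)(partials, x, zero(eltype(x)); ndrange = ngroups * groupsize)
    return sum(Array(partials))
end
```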