Skip to content

Commit

Permalink
Avoid the exception branch in expand (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored Jan 21, 2025
1 parent e5ef261 commit 4585ca9
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 33 deletions.
32 changes: 16 additions & 16 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ synchronize(backend)
```
"""
macro kernel(expr)
__kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false)
end

"""
Expand All @@ -68,7 +68,7 @@ This allows for two different configurations:
"""
macro kernel(ex...)
if length(ex) == 1
__kernel(ex[1], true, false)
return __kernel(ex[1], true, false)
else
generate_cpu = true
force_inbounds = false
Expand All @@ -88,7 +88,7 @@ macro kernel(ex...)
)
end
end
__kernel(ex[end], generate_cpu, force_inbounds)
return __kernel(ex[end], generate_cpu, force_inbounds)
end
end

Expand Down Expand Up @@ -206,7 +206,7 @@ a tuple corresponding to kernel configuration. In order to get
the total size you can use `prod(@groupsize())`.
"""
macro groupsize()
quote
return quote
$groupsize($(esc(:__ctx__)))
end
end
Expand All @@ -218,7 +218,7 @@ Query the ndrange on the backend. This function returns
a tuple corresponding to kernel configuration.
"""
macro ndrange()
quote
return quote
$size($ndrange($(esc(:__ctx__))))
end
end
Expand All @@ -232,7 +232,7 @@ macro localmem(T, dims)
# Stay in sync with CUDAnative
id = gensym("static_shmem")

quote
return quote
$SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id))))
end
end
Expand All @@ -253,7 +253,7 @@ macro private(T, dims)
if dims isa Integer
dims = (dims,)
end
quote
return quote
$Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims))))
end
end
Expand All @@ -265,7 +265,7 @@ Creates a private local of `mem` per item in the workgroup. This can be safely u
across [`@synchronize`](@ref) statements.
"""
macro private(expr)
esc(expr)
return esc(expr)
end

"""
Expand All @@ -275,7 +275,7 @@ end
that span workitems, or are reused across `@synchronize` statements.
"""
macro uniform(value)
esc(value)
return esc(value)
end

"""
Expand All @@ -286,7 +286,7 @@ from each thread in the workgroup are visible in from all other threads in the
workgroup.
"""
macro synchronize()
quote
return quote
$__synchronize()
end
end
Expand All @@ -303,7 +303,7 @@ workgroup. `cond` is not allowed to have any visible sideffects.
- `CPU`: This synchronization will always occur.
"""
macro synchronize(cond)
quote
return quote
$(esc(cond)) && $__synchronize()
end
end
Expand All @@ -328,7 +328,7 @@ end
```
"""
macro context()
esc(:(__ctx__))
return esc(:(__ctx__))
end

"""
Expand Down Expand Up @@ -368,7 +368,7 @@ macro print(items...)
end
end

quote
return quote
$__print($(map(esc, args)...))
end
end
Expand Down Expand Up @@ -424,7 +424,7 @@ macro index(locale, args...)
end

index_function = Symbol(:__index_, locale, :_, indexkind)
Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
return Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...)
end

###
Expand Down Expand Up @@ -662,7 +662,7 @@ struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun}
end

function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F}
Kernel{D, WS, ND, F}(kernel.backend, f)
return Kernel{D, WS, ND, F}(kernel.backend, f)
end

workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize
Expand Down Expand Up @@ -772,7 +772,7 @@ end
push!(args, item)
end

quote
return quote
print($(args...))
end
end
Expand Down
9 changes: 5 additions & 4 deletions src/cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ function (obj::Kernel{CPU})(args...; ndrange = nothing, workgroupsize = nothing)
end

__run(obj, ndrange, iterspace, args, dynamic, obj.backend.static)
return nothing
end

const CPU_GRAINSIZE = 1024 # Vectorization, 4x unrolling, minimal grain size
Expand Down Expand Up @@ -161,15 +162,15 @@ end

@inline function __index_Global_Linear(ctx, idx::CartesianIndex)
I = @inbounds expand(__iterspace(ctx), __groupindex(ctx), idx)
@inbounds LinearIndices(__ndrange(ctx))[I]
return @inbounds LinearIndices(__ndrange(ctx))[I]
end

# The Cartesian work-item index that is passed in is returned unchanged;
# the first (context) argument is ignored.
@inline __index_Local_Cartesian(_, idx::CartesianIndex) = idx

@inline function __index_Group_Cartesian(ctx, ::CartesianIndex)
__groupindex(ctx)
return __groupindex(ctx)
end

@inline function __index_Global_Cartesian(ctx, idx::CartesianIndex)
Expand All @@ -190,7 +191,7 @@ end
# CPU implementation of shared memory
###
@inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val) where {T, Dims}
MArray{__size(Dims), T}(undef)
return MArray{__size(Dims), T}(undef)
end

###
Expand All @@ -211,7 +212,7 @@ end
# https://github.com/JuliaLang/julia/issues/39308
@inline function aview(A, I::Vararg{Any, N}) where {N}
J = Base.to_indices(A, I)
Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
return Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...)
end

@inline function Base.getindex(A::ScratchArray{N}, idx) where {N}
Expand Down
8 changes: 5 additions & 3 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function find_return(stmt)
result |= @capture(expr, return x_)
expr
end
result
return result
end

# XXX: Proper errors
Expand Down Expand Up @@ -103,6 +103,7 @@ function transform_gpu!(def, constargs, force_inbounds)
Expr(:block, let_constargs...),
body,
)
return nothing
end

# The hard case, transform the function for CPU execution
Expand Down Expand Up @@ -137,6 +138,7 @@ function transform_cpu!(def, constargs, force_inbounds)
Expr(:block, let_constargs...),
Expr(:block, new_stmts...),
)
return nothing
end

struct WorkgroupLoop
Expand All @@ -150,7 +152,7 @@ end
is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_))

function is_scope_construct(expr::Expr)
expr.head === :block # ||
return expr.head === :block # ||
# expr.head === :let
end

Expand All @@ -160,7 +162,7 @@ function find_sync(stmt)
result |= is_sync(expr)
expr
end
result
return result
end

# TODO proper handling of LineInfo
Expand Down
50 changes: 43 additions & 7 deletions src/nditeration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract type _Size end
struct DynamicSize <: _Size end
struct StaticSize{S} <: _Size
function StaticSize{S}() where {S}
new{S::Tuple{Vararg{Int}}}()
return new{S::Tuple{Vararg{Int}}}()
end
end

Expand Down Expand Up @@ -51,11 +51,11 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
workitems::DynamicWorkitems

function NDRange{N, B, W}() where {N, B, W}
new{N, B, W, Nothing, Nothing}(nothing, nothing)
return new{N, B, W, Nothing, Nothing}(nothing, nothing)
end

function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
return new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
end
end

Expand All @@ -78,19 +78,55 @@ Base.length(range::NDRange) = length(blocks(range))
gidx = groupidx.I[I]
(gidx - 1) * stride + idx.I[I]
end
CartesianIndex(nI)
return CartesianIndex(nI)
end


"""
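    assume(cond::Bool)

Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling
it to optimize more aggressively. No runtime check is performed: per LLVM's `llvm.assume`
semantics, behavior is undefined if `cond` is false, so only pass conditions that are
guaranteed to hold.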
"""
@inline assume(cond::Bool) = Base.llvmcall(
# First argument: a tuple of (LLVM IR module text, name of the entry function in it).
(
"""
declare void @llvm.assume(i1)
define void @entry(i8) #0 {
%cond = icmp eq i8 %0, 1
call void @llvm.assume(i1 %cond)
ret void
}
attributes #0 = { alwaysinline }""", "entry",
),
# Remaining arguments: return type, tuple of argument types, and the value passed in.
# The `Bool` is received as an `i8` by `@entry`, which compares it with 1 to build
# the `i1` handed to `llvm.assume`.
Nothing, Tuple{Bool}, cond
)

# Issue one `assume(stop > 0)` hint per dimension of `CI`, using the `stop` field of
# each range stored in `CI.indices`. Returns the tuple of the `assume` results.
@inline function assume_nonzero(CI::CartesianIndices)
    rank = Val(ndims(CI))
    return ntuple(rank) do dim
        Base.@_inline_meta
        axis = CI.indices[dim]
        assume(axis.stop > 0)
    end
end

Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
# this causes an exception branch and a div
B = blocks(ndrange)
W = workitems(ndrange)
assume_nonzero(B)
assume_nonzero(W)
return expand(ndrange, B[groupidx], workitems(ndrange)[idx])
end

Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
expand(ndrange, groupidx, workitems(ndrange)[idx])
return expand(ndrange, groupidx, workitems(ndrange)[idx])
end

Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::CartesianIndex{N}) where {N}
expand(ndrange, blocks(ndrange)[groupidx], idx)
return expand(ndrange, blocks(ndrange)[groupidx], idx)
end

"""
Expand Down
6 changes: 3 additions & 3 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end


# Convenience method without an explicit `io`: forwards to the `io::IO` method of
# `ka_code_llvm`, passing `stdout` together with the same kernel, argument types and
# keyword arguments.
function ka_code_llvm(kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
    # Forward the caller's `workgroupsize`. It used to be hard-coded to `nothing`
    # here, which silently discarded the keyword for callers of this method.
    return ka_code_llvm(
        stdout, kernel, argtypes;
        ndrange = ndrange, workgroupsize = workgroupsize, kwargs...
    )
end

function ka_code_llvm(io::IO, kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...)
Expand Down Expand Up @@ -119,7 +119,7 @@ macro ka_code_typed(ex0...)

thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex)

quote
return quote
local $(esc(args)) = $(old_args)
# e.g. translate CuArray to CuBackendArray
$(esc(args)) = map(x -> argconvert($kern, x), $(esc(args)))
Expand Down Expand Up @@ -152,7 +152,7 @@ macro ka_code_llvm(ex0...)

thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex)

quote
return quote
local $(esc(args)) = $(old_args)

if isa($kern, Kernel{G} where {G <: GPU})
Expand Down

0 comments on commit 4585ca9

Please sign in to comment.