From 4585ca960217a11090e222ec8270e99ea063f4da Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 21 Jan 2025 15:08:01 +0100 Subject: [PATCH] Avoid the exception branch in expand (#518) --- src/KernelAbstractions.jl | 32 ++++++++++++------------- src/cpu.jl | 9 +++---- src/macros.jl | 8 ++++--- src/nditeration.jl | 50 +++++++++++++++++++++++++++++++++------ src/reflection.jl | 6 ++--- 5 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index 086bdbef2..b82dadc54 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -50,7 +50,7 @@ synchronize(backend) ``` """ macro kernel(expr) - __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false) + return __kernel(expr, #=generate_cpu=# true, #=force_inbounds=# false) end """ @@ -68,7 +68,7 @@ This allows for two different configurations: """ macro kernel(ex...) if length(ex) == 1 - __kernel(ex[1], true, false) + return __kernel(ex[1], true, false) else generate_cpu = true force_inbounds = false @@ -88,7 +88,7 @@ macro kernel(ex...) ) end end - __kernel(ex[end], generate_cpu, force_inbounds) + return __kernel(ex[end], generate_cpu, force_inbounds) end end @@ -206,7 +206,7 @@ a tuple corresponding to kernel configuration. In order to get the total size you can use `prod(@groupsize())`. """ macro groupsize() - quote + return quote $groupsize($(esc(:__ctx__))) end end @@ -218,7 +218,7 @@ Query the ndrange on the backend. This function returns a tuple corresponding to kernel configuration. 
""" macro ndrange() - quote + return quote $size($ndrange($(esc(:__ctx__)))) end end @@ -232,7 +232,7 @@ macro localmem(T, dims) # Stay in sync with CUDAnative id = gensym("static_shmem") - quote + return quote $SharedMemory($(esc(T)), Val($(esc(dims))), Val($(QuoteNode(id)))) end end @@ -253,7 +253,7 @@ macro private(T, dims) if dims isa Integer dims = (dims,) end - quote + return quote $Scratchpad($(esc(:__ctx__)), $(esc(T)), Val($(esc(dims)))) end end @@ -265,7 +265,7 @@ Creates a private local of `mem` per item in the workgroup. This can be safely u across [`@synchronize`](@ref) statements. """ macro private(expr) - esc(expr) + return esc(expr) end """ @@ -275,7 +275,7 @@ end that span workitems, or are reused across `@synchronize` statements. """ macro uniform(value) - esc(value) + return esc(value) end """ @@ -286,7 +286,7 @@ from each thread in the workgroup are visible in from all other threads in the workgroup. """ macro synchronize() - quote + return quote $__synchronize() end end @@ -303,7 +303,7 @@ workgroup. `cond` is not allowed to have any visible sideffects. - `CPU`: This synchronization will always occur. """ macro synchronize(cond) - quote + return quote $(esc(cond)) && $__synchronize() end end @@ -328,7 +328,7 @@ end ``` """ macro context() - esc(:(__ctx__)) + return esc(:(__ctx__)) end """ @@ -368,7 +368,7 @@ macro print(items...) end end - quote + return quote $__print($(map(esc, args)...)) end end @@ -424,7 +424,7 @@ macro index(locale, args...) end index_function = Symbol(:__index_, locale, :_, indexkind) - Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...) + return Expr(:call, GlobalRef(KernelAbstractions, index_function), esc(:__ctx__), map(esc, args)...) 
end ### @@ -662,7 +662,7 @@ struct Kernel{Backend, WorkgroupSize <: _Size, NDRange <: _Size, Fun} end function Base.similar(kernel::Kernel{D, WS, ND}, f::F) where {D, WS, ND, F} - Kernel{D, WS, ND, F}(kernel.backend, f) + return Kernel{D, WS, ND, F}(kernel.backend, f) end workgroupsize(::Kernel{D, WorkgroupSize}) where {D, WorkgroupSize} = WorkgroupSize @@ -772,7 +772,7 @@ end push!(args, item) end - quote + return quote print($(args...)) end end diff --git a/src/cpu.jl b/src/cpu.jl index 513648f5d..ac1f970f2 100644 --- a/src/cpu.jl +++ b/src/cpu.jl @@ -43,6 +43,7 @@ function (obj::Kernel{CPU})(args...; ndrange = nothing, workgroupsize = nothing) end __run(obj, ndrange, iterspace, args, dynamic, obj.backend.static) + return nothing end const CPU_GRAINSIZE = 1024 # Vectorization, 4x unrolling, minimal grain size @@ -161,7 +162,7 @@ end @inline function __index_Global_Linear(ctx, idx::CartesianIndex) I = @inbounds expand(__iterspace(ctx), __groupindex(ctx), idx) - @inbounds LinearIndices(__ndrange(ctx))[I] + return @inbounds LinearIndices(__ndrange(ctx))[I] end @inline function __index_Local_Cartesian(_, idx::CartesianIndex) @@ -169,7 +170,7 @@ end end @inline function __index_Group_Cartesian(ctx, ::CartesianIndex) - __groupindex(ctx) + return __groupindex(ctx) end @inline function __index_Global_Cartesian(ctx, idx::CartesianIndex) @@ -190,7 +191,7 @@ end # CPU implementation of shared memory ### @inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val) where {T, Dims} - MArray{__size(Dims), T}(undef) + return MArray{__size(Dims), T}(undef) end ### @@ -211,7 +212,7 @@ end # https://github.com/JuliaLang/julia/issues/39308 @inline function aview(A, I::Vararg{Any, N}) where {N} J = Base.to_indices(A, I) - Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...) + return Base.unsafe_view(Base._maybe_reshape_parent(A, Base.index_ndims(J...)), J...) 
end @inline function Base.getindex(A::ScratchArray{N}, idx) where {N} diff --git a/src/macros.jl b/src/macros.jl index a511758dc..570e2bf45 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -6,7 +6,7 @@ function find_return(stmt) result |= @capture(expr, return x_) expr end - result + return result end # XXX: Proper errors @@ -103,6 +103,7 @@ function transform_gpu!(def, constargs, force_inbounds) Expr(:block, let_constargs...), body, ) + return nothing end # The hard case, transform the function for CPU execution @@ -137,6 +138,7 @@ function transform_cpu!(def, constargs, force_inbounds) Expr(:block, let_constargs...), Expr(:block, new_stmts...), ) + return nothing end struct WorkgroupLoop @@ -150,7 +152,7 @@ end is_sync(expr) = @capture(expr, @synchronize() | @synchronize(a_)) function is_scope_construct(expr::Expr) - expr.head === :block # || + return expr.head === :block # || # expr.head === :let end @@ -160,7 +162,7 @@ function find_sync(stmt) result |= is_sync(expr) expr end - result + return result end # TODO proper handling of LineInfo diff --git a/src/nditeration.jl b/src/nditeration.jl index 8cfdbce95..cd05b2dd4 100644 --- a/src/nditeration.jl +++ b/src/nditeration.jl @@ -13,7 +13,7 @@ abstract type _Size end struct DynamicSize <: _Size end struct StaticSize{S} <: _Size function StaticSize{S}() where {S} - new{S::Tuple{Vararg{Int}}}() + return new{S::Tuple{Vararg{Int}}}() end end @@ -51,11 +51,11 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems} workitems::DynamicWorkitems function NDRange{N, B, W}() where {N, B, W} - new{N, B, W, Nothing, Nothing}(nothing, nothing) + return new{N, B, W, Nothing, Nothing}(nothing, nothing) end function NDRange{N, B, W}(blocks, workitems) where {N, B, W} - new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems) + return new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems) end end @@ -78,19 +78,55 @@ Base.length(range::NDRange) = length(blocks(range)) gidx = 
groupidx.I[I] (gidx - 1) * stride + idx.I[I] end - CartesianIndex(nI) + return CartesianIndex(nI) +end + + +""" + assume(cond::Bool) + +Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling +it to optimize more aggressively. +""" +@inline assume(cond::Bool) = Base.llvmcall( + ( + """ + declare void @llvm.assume(i1) + + define void @entry(i8) #0 { + %cond = icmp eq i8 %0, 1 + call void @llvm.assume(i1 %cond) + ret void + } + + attributes #0 = { alwaysinline }""", "entry", + ), + Nothing, Tuple{Bool}, cond +) + +@inline function assume_nonzero(CI::CartesianIndices) + return ntuple(Val(ndims(CI))) do I + Base.@_inline_meta + indices = CI.indices[I] + assume(indices.stop > 0) + end end Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer) - expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx]) + # this causes an exception branch and a div + B = blocks(ndrange) + W = workitems(ndrange) + assume_nonzero(B) + assume_nonzero(W) + return expand(ndrange, B[groupidx], W[idx]) end Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N} - expand(ndrange, groupidx, workitems(ndrange)[idx]) + return expand(ndrange, groupidx, workitems(ndrange)[idx]) end Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::CartesianIndex{N}) where {N} - expand(ndrange, blocks(ndrange)[groupidx], idx) + return expand(ndrange, blocks(ndrange)[groupidx], idx) end """ diff --git a/src/reflection.jl b/src/reflection.jl index da3ba1fbf..53142cc1a 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -34,7 +34,7 @@ end function ka_code_llvm(kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...) - ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = nothing, kwargs...) 
+ return ka_code_llvm(stdout, kernel, argtypes; ndrange = ndrange, workgroupsize = workgroupsize, kwargs...) end function ka_code_llvm(io::IO, kernel, argtypes; ndrange = nothing, workgroupsize = nothing, kwargs...) @@ -119,7 +119,7 @@ macro ka_code_typed(ex0...) thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_typed, ex) - quote + return quote local $(esc(args)) = $(old_args) # e.g. translate CuArray to CuBackendArray $(esc(args)) = map(x -> argconvert($kern, x), $(esc(args))) @@ -152,7 +152,7 @@ macro ka_code_llvm(ex0...) thecall = InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :ka_code_llvm, ex) - quote + return quote local $(esc(args)) = $(old_args) if isa($kern, Kernel{G} where {G <: GPU})