Skip to content

Commit

Permalink
Switch global state to keying everything on the context.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jan 23, 2025
1 parent 301f1d7 commit 5568373
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 90 deletions.
8 changes: 0 additions & 8 deletions lib/cl/platform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,6 @@ function devices(p::Platform, dtype)
return Device[Device(id) for id in result]
end

function default_device(p::Platform)
devs = devices(p, CL_DEVICE_TYPE_DEFAULT)
isempty(devs) && return nothing
# XXX: clGetDeviceIDs documents CL_DEVICE_TYPE_DEFAULT should only return one device,
# but it's been observed to return multiple devices on some platforms...
return first(devs)
end

devices(p::Platform) = devices(p, CL_DEVICE_TYPE_ALL)

function devices(p::Platform, dtype::Symbol)
Expand Down
192 changes: 119 additions & 73 deletions lib/cl/state.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,90 @@
## platform selection
# global state

function platform()
get!(task_local_storage(), :CLPlatform) do
ps = platforms()
if isempty(ps)
throw(ArgumentError("No OpenCL platforms found"))
function clear_task_local_storage!()
# the primary key for all task-local state is the context
delete!(task_local_storage(), :CLContext)

# all other state is derived
delete!(task_local_storage(), :CLDevice)
delete!(task_local_storage(), :CLPlatform)
delete!(task_local_storage(), :CLQueue)
delete!(task_local_storage(), :CLMemoryBackend)
end


## context creation

# we maintain a single global context per device
const device_contexts = Dict{Device, Context}()
const device_context_lock = ReentrantLock()
function device_context(dev::Device)
return @lock device_context_lock begin
get!(device_contexts, dev) do
device_contexts[dev] = Context(dev)
end
end
end

# prefer platforms that implement the full profile
idx = findfirst(ps) do p
p.profile == "FULL_PROFILE"
function context()
return get!(task_local_storage(), :CLContext) do
dev = if haskey(task_local_storage(), :CLDevice)
device()
elseif haskey(task_local_storage(), :CLPlatform)
default_device(platform())
else
default_device(default_platform())
end
isnothing(idx) || return ps[idx]
isnothing(dev) && throw(ArgumentError("No OpenCL devices found"))
device_context(dev)
end::Context
end

function context!(ctx::Context)
clear_task_local_storage!()
task_local_storage(:CLContext, ctx)
return ctx
end

# temporarily switch the current context to a different context
function context!(f::Base.Callable, args...)
old = context()
context!(args...)
try
f()
finally
context!(old)
end
end


## platform selection

function default_platform()
ps = platforms()
if isempty(ps)
throw(ArgumentError("No OpenCL platforms found"))
end

# prefer platforms that implement the full profile
idx = findfirst(ps) do p
p.profile == "FULL_PROFILE"
end
isnothing(idx) || return ps[idx]

# otherwise, just return the first platform
return first(ps)
end

# otherwise, just return the first platform
return first(ps)
function platform()
get!(task_local_storage(), :CLPlatform) do
device().platform
end::Platform
end

# allow overriding with a specific platform
function platform!(p::Platform)
clear_task_local_storage!()
task_local_storage(:CLPlatform, p)
delete!(task_local_storage(), :CLDevice)
delete!(task_local_storage(), :CLDeviceState)
delete!(task_local_storage(), :CLQueue)
return p
end

Expand All @@ -49,20 +110,24 @@ end

## device selection

function default_device(p::Platform)
devs = devices(p, CL_DEVICE_TYPE_DEFAULT)
isempty(devs) && return nothing
# XXX: clGetDeviceIDs documents CL_DEVICE_TYPE_DEFAULT should only return one device,
# but it's been observed to return multiple devices on some platforms...
return first(devs)
end

function device()
get!(task_local_storage(), :CLDevice) do
dev = default_device(platform())
isnothing(dev) && throw(ArgumentError("No OpenCL devices found"))
dev
only(context().devices)
end::Device
end

# allow overriding with a specific device
function device!(dev::Device)
clear_task_local_storage!()
task_local_storage(:CLDevice, dev)
task_local_storage(:CLPlatform, dev.platform)
delete!(task_local_storage(), :CLDeviceState)
delete!(task_local_storage(), :CLQueue)
return dev
end

Expand All @@ -73,65 +138,45 @@ function device!(dtype::Symbol)
device!(first(dev))
end

# temporarily switch the current device to a different device
function device!(f::Base.Callable, args...)
old = device()
device!(args...)
try
f()
finally
device!(old)
end
end

## per-device state

# each device is associated with a single context
# (and some other state we only want to set up once)
## memory back-end

abstract type AbstractMemoryBackend end
struct SVMBackend <: AbstractMemoryBackend end
struct USMBackend <: AbstractMemoryBackend end

struct DeviceState
context::Context
backend::AbstractMemoryBackend
end

const device_states = Dict{Device, DeviceState}()
const device_state_lock = ReentrantLock()
function device_state(dev::Device = device())
return get!(task_local_storage(), :CLDeviceState) do
@lock device_state_lock begin
get!(device_states, dev) do
ctx = Context(dev)

# validate memory support

# determine if USM is supported
usm = if usm_supported(dev)
caps = usm_capabilities(dev)
caps.host.access && caps.device.access
else
false
end

# determine if SVM is available (if needed)
if !usm
caps = svm_capabilities(dev)
if !caps.coarse_grain_buffer
error("Device $dev does not support USM or coarse-grained SVM, either of which is required by OpenCL.jl")
end
end

backend = usm ? USMBackend() : SVMBackend()
device_states[dev] = DeviceState(ctx, backend)
end
function memory_backend()
return get!(task_local_storage(), :CLMemoryBackend) do
dev = device()

# determine if USM is supported
usm = if usm_supported(dev)
caps = usm_capabilities(dev)
caps.host.access && caps.device.access
else
false
end
end::DeviceState
end

context(dev::Device = device()) = device_state(dev).context
memory_backend(dev::Device = device()) = device_state(dev).backend
# determine if SVM is available (if needed)
if !usm
caps = svm_capabilities(dev)
if !caps.coarse_grain_buffer
error("Device $dev does not support USM or coarse-grained SVM, either of which is required by OpenCL.jl")
end
end

# temporarily switch the current device to a different device
function device!(f::Base.Callable, args...)
old = device()
device!(args...)
try
f()
finally
device!(old)
usm ? USMBackend() : SVMBackend()
end
end

Expand All @@ -147,16 +192,17 @@ function queue()
Dict{Device, CmdQueue}()
end

queue = get!(queues, dev) do
get!(queues, dev) do
CmdQueue()
end
task_local_storage(:CLQueue, queue)
queue
end::CmdQueue
end

# switch the current task to a different queue
function queue!(q::CmdQueue)
if q.device != device()
throw(ArgumentError("Cannot switch to a queue on a different device"))
end
task_local_storage(:CLQueue, q)
return q
end
Expand Down
16 changes: 7 additions & 9 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,18 +370,16 @@ for (srcty, dstty) in [(:Array, :CLArray), (:CLArray, :Array), (:CLArray, :CLArr
nbytes = N * sizeof(T)
nbytes == 0 && return

dev = if $dstty == CLArray
device(dst)
ctx = if $dstty == CLArray
context(dst)
else
device(src)
context(src)
end
cl.device!(dev) do
if cl.memory_backend(dev) == cl.SVMBackend()
cl.context!(ctx) do
if cl.memory_backend() == cl.SVMBackend()
cl.enqueue_svm_copy(pointer(dst, dst_off), pointer(src, src_off), nbytes; blocking)
elseif cl.memory_backend(dev) == cl.USMBackend()
elseif cl.memory_backend() == cl.USMBackend()
cl.enqueue_usm_copy(pointer(dst, dst_off), pointer(src, src_off), nbytes; blocking)
else
error(cl.memory_backend(dev))
end
end
end
Expand Down Expand Up @@ -503,7 +501,7 @@ function Base.resize!(a::CLVector{T}, n::Integer) where {T}

# replace the data with a new CL. this 'unshares' the array.
# as a result, we can safely support resizing unowned buffers.
new_data = cl.device!(device(a)) do
new_data = cl.context!(context(a)) do
mem = alloc(memtype(a), bufsize; alignment=Base.datatype_alignment(T))
ptr = convert(CLPtr{T}, mem)
m = min(length(a), n)
Expand Down

0 comments on commit 5568373

Please sign in to comment.