Commit 201a42f

test with async dataloader, make source optional to yield
rmosolgo committed Jan 31, 2025
1 parent 2e13fc1 commit 201a42f
Showing 13 changed files with 1,838 additions and 90 deletions.
lib/graphql/current.rb (5 additions & 0 deletions)
@@ -48,5 +48,10 @@ def self.field
     def self.dataloader_source_class
       Fiber[:__graphql_current_dataloader_source]&.class
     end
+
+    # @return [GraphQL::Dataloader::Source, nil] The currently-running source, if there is one
+    def self.dataloader_source
+      Fiber[:__graphql_current_dataloader_source]
+    end
   end
 end
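
Note: the new `GraphQL::Current.dataloader_source` pairs with the existing `dataloader_source_class` above, exposing the running source instance (not just its class) to instrumentation code. A minimal usage sketch (the `log_current_source` helper is hypothetical, not part of this commit):

    # Called from code that may run inside a source's fetch:
    def log_current_source
      # GraphQL::Current.dataloader_source returns nil outside a source fiber
      if (source = GraphQL::Current.dataloader_source)
        puts "Currently fetching via #{source.class}"
      end
    end
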
lib/graphql/dataloader.rb (3 additions & 3 deletions)
@@ -129,7 +129,7 @@ def with(source_class, *batch_args, **batch_kwargs)
     # Dataloader will resume the fiber after the requested data has been loaded (by another Fiber).
     #
     # @return [void]
-    def yield(source)
+    def yield(source = Fiber[:__graphql_current_dataloader_source])
       trace = Fiber[:__graphql_current_multiplex]&.current_trace
       trace&.dataloader_fiber_yield(source)
       Fiber.yield
@@ -195,7 +195,7 @@ def run
       next_source_fibers = []
       first_pass = true
       manager = spawn_fiber do
-        trace&.begin_dataloader
+        trace&.begin_dataloader(self)
         while first_pass || !job_fibers.empty?
           first_pass = false

@@ -222,7 +222,7 @@
           end
         end

-        trace&.end_dataloader
+        trace&.end_dataloader(self)
       end

       run_fiber(manager)
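
Note: with the new default argument, `Dataloader#yield` reads the currently-running source from the fiber-local that the dataloader already sets, so callers no longer pass it explicitly. A sketch of a custom source that yields while waiting (the `WaitingSource` class and its two helpers are illustrative, assuming the `dataloader` reader that `GraphQL::Dataloader::Source` provides):

    class WaitingSource < GraphQL::Dataloader::Source
      def fetch(keys)
        until batch_ready?(keys)  # hypothetical readiness check
          dataloader.yield        # previously required: dataloader.yield(self)
        end
        keys.map { |k| read_result(k) }  # hypothetical result lookup
      end
    end
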
lib/graphql/dataloader/async_dataloader.rb (17 additions & 5 deletions)
@@ -2,16 +2,20 @@
 module GraphQL
   class Dataloader
     class AsyncDataloader < Dataloader
-      def yield(_source)
+      def yield(source = Fiber[:__graphql_current_dataloader_source])
+        trace = Fiber[:__graphql_current_multiplex]&.current_trace
+        trace&.dataloader_fiber_yield(source)
         if (condition = Fiber[:graphql_dataloader_next_tick])
           condition.wait
         else
           Fiber.yield
         end
+        trace&.dataloader_fiber_resume(source)
         nil
       end

       def run
+        trace = Fiber[:__graphql_current_multiplex]&.current_trace
         jobs_fiber_limit, total_fiber_limit = calculate_fiber_limit
         job_fibers = []
         next_job_fibers = []
@@ -20,11 +24,12 @@ def run
         first_pass = true
         sources_condition = Async::Condition.new
         manager = spawn_fiber do
+          trace&.begin_dataloader(self)
           while first_pass || !job_fibers.empty?
             first_pass = false
             fiber_vars = get_fiber_variables

-            while (f = (job_fibers.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size) < jobs_fiber_limit) && spawn_job_fiber(nil))))
+            while (f = (job_fibers.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size) < jobs_fiber_limit) && spawn_job_fiber(trace))))
               if f.alive?
                 finished = run_fiber(f)
                 if !finished
@@ -38,7 +43,7 @@
             Sync do |root_task|
               set_fiber_variables(fiber_vars)
               while !source_tasks.empty? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
-                while (task = (source_tasks.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size + next_source_tasks.size) < total_fiber_limit) && spawn_source_task(root_task, sources_condition))))
+                while (task = (source_tasks.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size + next_source_tasks.size) < total_fiber_limit) && spawn_source_task(root_task, sources_condition, trace))))
                   if task.alive?
                     root_task.yield # give the source task a chance to run
                     next_source_tasks << task
@@ -50,6 +55,7 @@
               end
             end
           end
+          trace&.end_dataloader(self)
         end

         manager.resume
@@ -63,7 +69,7 @@ def run

       private

-      def spawn_source_task(parent_task, condition)
+      def spawn_source_task(parent_task, condition, trace)
         pending_sources = nil
         @source_cache.each_value do |source_by_batch_params|
           source_by_batch_params.each_value do |source|
@@ -77,10 +83,16 @@ def spawn_source_task(parent_task, condition)
         if pending_sources
           fiber_vars = get_fiber_variables
           parent_task.async do
+            trace&.dataloader_spawn_source_fiber(pending_sources)
             set_fiber_variables(fiber_vars)
             Fiber[:graphql_dataloader_next_tick] = condition
-            pending_sources.each(&:run_pending_keys)
+            pending_sources.each do |s|
+              trace&.begin_dataloader_source(s)
+              s.run_pending_keys
+              trace&.end_dataloader_source(s)
+            end
             cleanup_fiber
+            trace&.dataloader_fiber_exit
           end
         end
       end
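
Note: `run` now captures the current trace once and threads it through `spawn_job_fiber` and `spawn_source_task`, so job fibers and source tasks can emit begin/end events even though they execute on their own fibers. For reference, this dataloader is opted into per schema (a sketch; the schema and query type names are illustrative):

    class MySchema < GraphQL::Schema
      query Types::Query  # assumes a query type defined elsewhere
      # Requires the async gem; sources run as Async tasks:
      use GraphQL::Dataloader::AsyncDataloader
    end
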
lib/graphql/tracing/perfetto_trace.rb (5 additions & 5 deletions)
@@ -332,7 +332,7 @@ def dataloader_fiber_yield(source)
         if (flow_id = ls.track_event.flow_ids.first)
           # got it
         else
-          flow_id = rand(999_999)
+          flow_id = ls.track_event.name.object_id
           ls.track_event = dup_with(ls.track_event, {flow_ids: [flow_id] })
         end
         @flow_ids[source] << flow_id
@@ -395,7 +395,7 @@ def dataloader_fiber_exit
         super
       end

-      def begin_dataloader
+      def begin_dataloader(dl)
         @packets << TracePacket.new(
           timestamp: ts,
           track_event: TrackEvent.new(
@@ -408,15 +408,15 @@
         @did = fid
         @packets << TracePacket.new(
           track_descriptor: TrackDescriptor.new(
-            uuid: fid,
-            name: "Dataloader Fiber ##{fid}",
+            uuid: @did,
+            name: "Dataloader Fiber ##{@did}",
             parent_uuid: @main_fiber_id,
           )
         )
         super
       end

-      def end_dataloader
+      def end_dataloader(dl)
         @packets << TracePacket.new(
           timestamp: ts,
           track_event: TrackEvent.new(
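
Note: deriving the flow id from the track event name's `object_id` presumably avoids the collisions that `rand(999_999)` allowed between unrelated flows. For reference, the tracer is attached with `trace_with` and written out as the spec below does (the schema name is illustrative):

    class TracedSchema < MySchema  # e.g. the schema sketched above
      trace_with GraphQL::Tracing::PerfettoTrace
    end

    res = TracedSchema.execute("{ __typename }")
    # Binary protobuf for ui.perfetto.dev (or pass debug_json: true for JSON):
    res.context.query.current_trace.write(file: "perfetto.dump")
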
lib/graphql/tracing/trace.rb (4 additions & 2 deletions)
@@ -115,11 +115,13 @@ def resolve_type_lazy(query:, type:, object:)
       end

       # A dataloader run is starting
+      # @param dataloader [GraphQL::Dataloader]
       # @return [void]
-      def begin_dataloader; end
+      def begin_dataloader(dataloader); end
       # A dataloader run has ended
+      # @param dataloader [GraphQL::Dataloader]
       # @return [void]
-      def end_dataloader; end
+      def end_dataloader(dataloader); end

       # A source with pending keys is about to fetch
       # @param source [GraphQL::Dataloader::Source]
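
Note: because these hooks now receive the dataloader instance, a custom trace can tell runs apart or inspect the dataloader itself. A minimal sketch against the new signatures (the `DataloaderTimingTrace` module is hypothetical):

    module DataloaderTimingTrace
      def begin_dataloader(dataloader)
        @dataloader_started_at = Time.now
        super
      end

      def end_dataloader(dataloader)
        elapsed = Time.now - @dataloader_started_at
        puts "#{dataloader.class} finished in #{elapsed.round(4)}s"
        super
      end
    end

    # Attached to a schema like any other trace module:
    #   trace_with DataloaderTimingTrace
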
spec/graphql/dataloader/async_dataloader_spec.rb (57 additions & 9 deletions)
@@ -148,10 +148,6 @@ def fiber_local_context(key:)
     module AsyncDataloaderAssertions
       def self.included(child_class)
         child_class.class_eval do
-          before do
-            AsyncSchema::KeyWaitForSource.reset
-          end
-
           it "works with sources" do
             dataloader = GraphQL::Dataloader::AsyncDataloader.new
             r1 = dataloader.with(AsyncSchema::SleepSource, :s1).request(0.1)
@@ -181,7 +177,7 @@ def self.included(child_class)

it "works with GraphQL" do
started_at = Time.now
res = AsyncSchema.execute("{ s1: sleep(duration: 0.1) s2: sleep(duration: 0.2) s3: sleep(duration: 0.3) }")
res = @schema.execute("{ s1: sleep(duration: 0.1) s2: sleep(duration: 0.2) s3: sleep(duration: 0.3) }")
ended_at = Time.now
assert_equal({"s1"=>0.1, "s2"=>0.2, "s3"=>0.3}, res["data"])
assert_in_delta 0.3, ended_at - started_at, 0.05, "IO ran in parallel"
@@ -208,7 +204,7 @@ def self.included(child_class)
               }
             GRAPHQL
             started_at = Time.now
-            res = AsyncSchema.execute(query_str)
+            res = @schema.execute(query_str)
             ended_at = Time.now

             expected_data = {
@@ -251,7 +247,7 @@ def self.included(child_class)
               }
             GRAPHQL
             started_at = Time.now
-            res = AsyncSchema.execute(query_str)
+            res = @schema.execute(query_str)
             ended_at = Time.now

             expected_data = {
@@ -279,7 +275,7 @@ def self.included(child_class)
             GRAPHQL

             t1 = Time.now
-            result = AsyncSchema.execute(query_str)
+            result = @schema.execute(query_str)
             t2 = Time.now
             assert_equal ["a", "b", "c"], result["data"]["listWaiters"].map { |lw| lw["waiter"]["tag"]}
             # The field itself waits 0.1
@@ -297,15 +293,67 @@ def self.included(child_class)
               }
             GRAPHQL

-            result = AsyncSchema.execute(query_str)
+            result = @schema.execute(query_str)
             assert_equal value, result['data']['fiberLocalContext']
           end
         end
       end
     end

     describe "with async" do
+      before do
+        @schema = AsyncSchema
+        AsyncSchema::KeyWaitForSource.reset
+      end
       include AsyncDataloaderAssertions
     end
+
+    describe "with perfetto trace turned on" do
+      class TraceAsyncSchema < AsyncSchema
+        trace_with GraphQL::Tracing::PerfettoTrace
+        use GraphQL::Dataloader::AsyncDataloader
+      end
+
+      before do
+        @schema = TraceAsyncSchema
+        AsyncSchema::KeyWaitForSource.reset
+      end
+
+      include AsyncDataloaderAssertions
+      include PerfettoSnapshot
+
+      focus
+      it "produces a trace" do
+        query_str = <<-GRAPHQL
+          {
+            s1: sleeper(duration: 0.1) {
+              sleeper(duration: 0.1) {
+                sleeper(duration: 0.1) {
+                  duration
+                }
+              }
+            }
+            s2: sleeper(duration: 0.2) {
+              sleeper(duration: 0.1) {
+                duration
+              }
+            }
+            s3: sleeper(duration: 0.3) {
+              duration
+            }
+          }
+        GRAPHQL
+        res = @schema.execute(query_str)
+        if ENV["DUMP_PERFETTO"]
+          res.context.query.current_trace.write(file: "perfetto.dump")
+        end
+
+        json = res.context.query.current_trace.write(file: nil, debug_json: true)
+        data = JSON.parse(json)
+
+        check_snapshot(data, "example.json")
+      end
+    end
   end
 end
(7 more changed files not shown.)
