From d1311be5c164dae69a9715b77772f880552c752f Mon Sep 17 00:00:00 2001
From: Antoniu Pop
Date: Thu, 6 Feb 2025 12:41:32 +0000
Subject: [PATCH] fix(scheduler): enable per FHE computation error return from
 scheduler

---
 fhevm-engine/coprocessor/src/tfhe_worker.rs |  32 ++--
 fhevm-engine/executor/src/server.rs         |  19 ++-
 fhevm-engine/scheduler/src/dfg.rs           |  13 +-
 fhevm-engine/scheduler/src/dfg/scheduler.rs | 160 +++++++++++---------
 fhevm-engine/scheduler/src/dfg/types.rs     |   3 +-
 5 files changed, 135 insertions(+), 92 deletions(-)

diff --git a/fhevm-engine/coprocessor/src/tfhe_worker.rs b/fhevm-engine/coprocessor/src/tfhe_worker.rs
index c1bb74e7..6c5a3086 100644
--- a/fhevm-engine/coprocessor/src/tfhe_worker.rs
+++ b/fhevm-engine/coprocessor/src/tfhe_worker.rs
@@ -1,11 +1,13 @@
+use crate::types::CoprocessorError;
 use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
-use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts};
+use fhevm_engine_common::types::{FhevmError, Handle, SupportedFheCiphertexts};
 use fhevm_engine_common::{tfhe_ops::current_ciphertext_version, types::SupportedFheOperations};
 use itertools::Itertools;
 use lazy_static::lazy_static;
 use opentelemetry::trace::{Span, TraceContextExt, Tracer};
 use opentelemetry::KeyValue;
 use prometheus::{register_int_counter, IntCounter};
+use scheduler::dfg::types::SchedulerError;
 use scheduler::dfg::{scheduler::Scheduler, types::DFGTaskInput, DFGraph};
 use sqlx::{postgres::PgListener, query, Acquire};
 use std::{
@@ -357,29 +359,37 @@
                 let keys = rk.get(tenant_id).expect("Can't get tenant key from cache");
                 // Schedule computations in parallel as dependences allow
+                tfhe::set_server_key(keys.sks.clone());
                 let mut sched = Scheduler::new(&mut graph.graph, args.coprocessor_fhe_threads);
                 sched.schedule(keys.sks.clone()).await?;
             }
             // Extract the results from the graph
-            let res = graph.get_results().unwrap();
+            let mut res = graph.get_results();
             for (idx, w) in work.iter().enumerate() {
                 // Filter out computations that could not complete
                 if uncomputable.contains_key(&idx) {
                     continue;
                 }
-                let r = res.iter().find(|(h, _)| *h == w.output_handle).unwrap();
-                {
-                    let mut rk = tenant_key_cache.write().await;
-                    let keys = rk
-                        .get(&w.tenant_id)
-                        .expect("Can't get tenant key from cache");
-                    tfhe::set_server_key(keys.sks.clone());
-                }
+                let r = &mut res
+                    .iter_mut()
+                    .find(|(h, _)| *h == w.output_handle)
+                    .unwrap()
+                    .1;
+
                 let finished_work_unit: Result<
                     _,
                     (Box<(dyn std::error::Error + Send + Sync)>, i32, Vec<u8>),
-                > = Ok((w, r.1 .1, &r.1 .2));
+                > = r
+                    .as_mut()
+                    .map(|rok| (w, rok.1, std::mem::take(&mut rok.2)))
+                    .map_err(|_rerr| {
+                        (
+                            CoprocessorError::SchedulerError(SchedulerError::SchedulerError).into(),
+                            w.tenant_id,
+                            w.output_handle.clone(),
+                        )
+                    });
                 match finished_work_unit {
                     Ok((w, db_type, db_bytes)) => {
                         let mut s = tracer.start_with_context("insert_ct_into_db", &loop_ctx);
diff --git a/fhevm-engine/executor/src/server.rs b/fhevm-engine/executor/src/server.rs
index 6196c621..8f17ad0f 100644
--- a/fhevm-engine/executor/src/server.rs
+++ b/fhevm-engine/executor/src/server.rs
@@ -120,13 +120,22 @@ impl FhevmExecutor for FhevmExecutorService {
                 now.elapsed().unwrap().as_millis()
             );
             // Extract the results from the graph
-            match graph.get_results() {
-                Ok(mut result_cts) => Some(Resp::ResultCiphertexts(ResultCiphertexts {
-                    ciphertexts: result_cts
+            let results = graph.get_results();
+            let outputs: Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>> =
+                results
+                    .into_iter()
+                    .map(|(h, output)| match output {
+                        Ok(output) => Ok((h, output)),
+                        Err(e) => Err(e),
+                    })
+                    .collect();
+            match outputs {
+                Ok(mut outputs) => Some(Resp::ResultCiphertexts(ResultCiphertexts {
+                    ciphertexts: outputs
                         .iter_mut()
-                        .map(|(h, ct)| CompressedCiphertext {
+                        .map(|(h, output)| CompressedCiphertext {
                             handle: h.clone(),
-                            serialization: std::mem::take(&mut ct.2),
+                            serialization: std::mem::take(&mut output.2),
                         })
                         .collect(),
                 })),
diff --git a/fhevm-engine/scheduler/src/dfg.rs b/fhevm-engine/scheduler/src/dfg.rs
index e0c5235b..aca11d4d 100644
--- a/fhevm-engine/scheduler/src/dfg.rs
+++ b/fhevm-engine/scheduler/src/dfg.rs
@@ -71,16 +71,19 @@ impl DFGraph {
 
     pub fn get_results(
         &mut self,
-    ) -> Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>> {
+    ) -> Vec<(Handle, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)> {
         let mut res = Vec::with_capacity(self.graph.node_count());
         for index in 0..self.graph.node_count() {
             let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap();
-            if let Some(ct) = &node.result {
-                res.push((node.result_handle.clone(), ct.clone()));
+            if let Some(ct) = std::mem::take(&mut node.result) {
+                res.push((node.result_handle.clone(), ct));
             } else {
-                return Err(SchedulerError::DataflowGraphError.into());
+                res.push((
+                    node.result_handle.clone(),
+                    Err(SchedulerError::DataflowGraphError.into()),
+                ));
             }
         }
-        Ok(res)
+        res
     }
 }
diff --git a/fhevm-engine/scheduler/src/dfg/scheduler.rs b/fhevm-engine/scheduler/src/dfg/scheduler.rs
index b918cc5a..5c1397f9 100644
--- a/fhevm-engine/scheduler/src/dfg/scheduler.rs
+++ b/fhevm-engine/scheduler/src/dfg/scheduler.rs
@@ -98,7 +98,7 @@ impl<'a> Scheduler<'a> {
     }
 
     async fn schedule_fine_grain(&mut self, server_key: tfhe::ServerKey) -> Result<()> {
-        let mut set: JoinSet<Result<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>> =
+        let mut set: JoinSet<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)> =
             JoinSet::new();
         tfhe::set_server_key(server_key.clone());
         // Prime the scheduler with all nodes without dependences
@@ -127,36 +127,39 @@
         }
         // Get results from computations and update dependences of remaining computations
         while let Some(result) = set.join_next().await {
-            let output = result??;
-            let index = output.0;
+            let result = result?;
+            let index = result.0;
             let node_index = NodeIndex::new(index);
-            // Satisfy deps from the executed task
-            for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
-                let sks = server_key.clone();
-                let child_index = edge.target();
-                let child_node = self.graph.node_weight_mut(child_index).unwrap();
-                child_node.inputs[*edge.weight() as usize] =
-                    DFGTaskInput::Value(output.1 .0.clone());
-                if Self::is_ready(child_node) {
-                    let opcode = child_node.opcode;
-                    let inputs: Result<Vec<SupportedFheCiphertexts>> = child_node
-                        .inputs
-                        .iter()
-                        .map(|i| match i {
-                            DFGTaskInput::Value(i) => Ok(i.clone()),
-                            DFGTaskInput::Compressed((t, c)) => {
-                                SupportedFheCiphertexts::decompress(*t, c)
-                            }
-                            _ => Err(SchedulerError::UnsatisfiedDependence.into()),
-                        })
-                        .collect();
-                    set.spawn_blocking(move || {
-                        tfhe::set_server_key(sks.clone());
-                        run_computation(opcode, inputs, child_index.index())
-                    });
+            if let Ok(output) = &result.1 {
+                // Satisfy deps from the executed task
+                for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
+                    let sks = server_key.clone();
+                    let child_index = edge.target();
+                    let child_node = self.graph.node_weight_mut(child_index).unwrap();
+                    child_node.inputs[*edge.weight() as usize] =
+                        DFGTaskInput::Value(output.0.clone());
+                    if Self::is_ready(child_node) {
+                        let opcode = child_node.opcode;
+                        let inputs: Result<Vec<SupportedFheCiphertexts>> = child_node
+                            .inputs
+                            .iter()
+                            .map(|i| match i {
+                                DFGTaskInput::Value(i) => Ok(i.clone()),
+                                DFGTaskInput::Compressed((t, c)) => {
+                                    SupportedFheCiphertexts::decompress(*t, c)
+                                }
+                                _ => Err(SchedulerError::UnsatisfiedDependence.into()),
+                            })
+                            .collect();
+                        set.spawn_blocking(move || {
+                            tfhe::set_server_key(sks.clone());
+                            run_computation(opcode, inputs, child_index.index())
+                        });
+                    }
                 }
             }
-            self.graph[node_index].result = Some(output.1);
+            let node_index = NodeIndex::new(result.0);
+            self.graph[node_index].result = Some(result.1);
         }
         Ok(())
     }
@@ -167,12 +170,10 @@
         server_key: tfhe::ServerKey,
     ) -> Result<()> {
         tfhe::set_server_key(server_key.clone());
-        let mut set: JoinSet<
-            Result<(
-                Vec<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>,
-                NodeIndex,
-            )>,
-        > = JoinSet::new();
+        let mut set: JoinSet<(
+            Vec<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)>,
+            NodeIndex,
+        )> = JoinSet::new();
         let mut execution_graph: Dag<ExecNode, ()> = Dag::default();
         let _ = match strategy {
             PartitionStrategy::MaxLocality => {
@@ -205,18 +206,25 @@
         }
         // Get results from computations and update dependences of remaining computations
         while let Some(result) = set.join_next().await {
-            let mut output = result??;
-            let task_index = output.1;
-            while let Some(o) = output.0.pop() {
+            let mut result = result?;
+            let task_index = result.1;
+            while let Some(o) = result.0.pop() {
                 let index = o.0;
                 let node_index = NodeIndex::new(index);
-                // Satisfy deps from the executed computation in the DFG
-                for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
-                    let child_index = edge.target();
-                    let child_node = self.graph.node_weight_mut(child_index).unwrap();
-                    if !child_node.inputs.is_empty() {
-                        child_node.inputs[*edge.weight() as usize] =
-                            DFGTaskInput::Value(o.1 .0.clone());
+                // If this node result is an error, we can't satisfy
+                // any dependences with it, so skip - all dependences
+                // on this will remain unsatisfied and result in
+                // further errors.
+                if o.1.is_ok() {
+                    // Satisfy deps from the executed computation in the DFG
+                    for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
+                        let child_index = edge.target();
+                        let child_node = self.graph.node_weight_mut(child_index).unwrap();
+                        if !child_node.inputs.is_empty() {
+                            // Here cannot be an error
+                            child_node.inputs[*edge.weight() as usize] =
+                                DFGTaskInput::Value(o.1.as_ref().unwrap().0.clone());
+                        }
                     }
                 }
                 self.graph[node_index].result = Some(o.1);
@@ -268,7 +276,6 @@
             }
         }
-
         let (src, dest) = channel();
         let rayon_threads = self.rayon_threads;
@@ -288,12 +295,12 @@
                 ))
                 .unwrap();
             });
-        }).await?;
-
+        })
+        .await?;
+
         let results: Vec<_> = dest.iter().collect();
-        for result in results {
-            let mut output = result?;
-            while let Some(o) = output.0.pop() {
+        for mut result in results {
+            while let Some(o) = result.0.pop() {
                 let index = o.0;
                 let node_index = NodeIndex::new(index);
                 self.graph[node_index].result = Some(o.1);
@@ -424,35 +431,45 @@ fn execute_partition(
     use_global_threadpool: bool,
     rayon_threads: usize,
     server_key: tfhe::ServerKey,
-) -> Result<(
-    Vec<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>,
+) -> (
+    Vec<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)>,
     NodeIndex,
-)> {
-    let mut res: HashMap<usize, (SupportedFheCiphertexts, i16, Vec<u8>)> =
+) {
+    let mut res: HashMap<usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>> =
         HashMap::with_capacity(computations.len());
-    for (opcode, inputs, nidx) in computations {
+    'comps: for (opcode, inputs, nidx) in computations {
         let mut cts = Vec::with_capacity(inputs.len());
         for i in inputs.iter() {
             match i {
                 DFGTaskInput::Dependence(d) => {
                     if let Some(d) = d {
-                        if let Some(ct) = res.get(d) {
+                        if let Some(Ok(ct)) = res.get(d) {
                             cts.push(ct.0.clone());
+                        } else {
+                            res.insert(
+                                nidx.index(),
+                                Err(SchedulerError::UnsatisfiedDependence.into()),
+                            );
+                            continue 'comps;
                         }
-                    } else {
-                        return Err(SchedulerError::UnsatisfiedDependence.into());
                     }
                 }
                 DFGTaskInput::Value(v) => {
                     cts.push(v.clone());
                 }
                 DFGTaskInput::Compressed((t, c)) => {
-                    cts.push(SupportedFheCiphertexts::decompress(*t, c)?);
+                    let decomp = SupportedFheCiphertexts::decompress(*t, c);
+                    if let Ok(decomp) = decomp {
+                        cts.push(decomp);
+                    } else {
+                        res.insert(nidx.index(), Err(decomp.err().unwrap()));
+                        continue 'comps;
+                    }
                 }
             }
         }
         if use_global_threadpool {
-            let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?;
+            let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index());
             res.insert(node_index, result);
         } else {
             let thread_pool = THREAD_POOL
@@ -470,38 +487,41 @@
             thread_pool.broadcast(|_| {
                 tfhe::set_server_key(server_key.clone());
             });
-            thread_pool.install(|| -> Result<()> {
-                let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?;
+            let _ = thread_pool.install(|| -> Result<()> {
+                let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index());
                 res.insert(node_index, result);
                 Ok(())
-            })?;
+            });
             THREAD_POOL.set(Some(thread_pool));
         }
     }
-    Ok((Vec::from_iter(res), task_id))
+    (Vec::from_iter(res), task_id)
 }
 
 fn run_computation(
     operation: i32,
     inputs: Result<Vec<SupportedFheCiphertexts>>,
     graph_node_index: usize,
-) -> Result<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))> {
+) -> (usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>) {
     let op = FheOperation::try_from(operation);
     match inputs {
         Ok(inputs) => match op {
            Ok(FheOperation::FheGetCiphertext) => {
                 let (ct_type, ct_bytes) = inputs[0].compress();
-                Ok((graph_node_index, (inputs[0].clone(), ct_type, ct_bytes)))
+                (graph_node_index, Ok((inputs[0].clone(), ct_type, ct_bytes)))
             }
             Ok(_) => match perform_fhe_operation(operation as i16, &inputs) {
                 Ok(result) => {
                     let (ct_type, ct_bytes) = result.compress();
-                    Ok((graph_node_index, (result.clone(), ct_type, ct_bytes)))
+                    (graph_node_index, Ok((result.clone(), ct_type, ct_bytes)))
                 }
-                Err(e) => Err(e.into()),
+                Err(e) => (graph_node_index, Err(e.into())),
             },
-            _ => Err(SchedulerError::UnknownOperation(operation).into()),
+            _ => (
+                graph_node_index,
+                Err(SchedulerError::UnknownOperation(operation).into()),
+            ),
         },
-        Err(_) => Err(SchedulerError::InvalidInputs.into()),
+        Err(_) => (graph_node_index, Err(SchedulerError::InvalidInputs.into())),
     }
 }
diff --git a/fhevm-engine/scheduler/src/dfg/types.rs b/fhevm-engine/scheduler/src/dfg/types.rs
index 23706de9..fd628d32 100644
--- a/fhevm-engine/scheduler/src/dfg/types.rs
+++ b/fhevm-engine/scheduler/src/dfg/types.rs
@@ -1,6 +1,7 @@
+use anyhow::Result;
 use fhevm_engine_common::types::SupportedFheCiphertexts;
 
-pub type DFGTaskResult = Option<(SupportedFheCiphertexts, i16, Vec<u8>)>;
+pub type DFGTaskResult = Option<Result<(SupportedFheCiphertexts, i16, Vec<u8>)>>;
 
 #[derive(Clone)]
 pub enum DFGTaskInput {
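
Usage sketch (illustrative, not part of the patch): after this change DFGraph::get_results()
returns one Result per output handle instead of failing the whole batch, so callers can handle
each computation's outcome independently, as the coprocessor and executor hunks above do.
Assuming a scheduled `graph: DFGraph`, consuming the new shape looks roughly like this:

    for (handle, output) in graph.get_results() {
        match output {
            // A successful node yields the ciphertext, its type id and serialized bytes.
            Ok((_ct, ct_type, bytes)) => {
                println!("handle {:?}: type {}, {} bytes", handle, ct_type, bytes.len());
            }
            // Only this computation failed; results for other handles remain usable.
            Err(e) => eprintln!("handle {:?}: {}", handle, e),
        }
    }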