Skip to content

Commit

Permalink
fix(scheduler): enable per FHE computation error return from scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniupop committed Feb 6, 2025
1 parent 5923e8f commit d1311be
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 92 deletions.
32 changes: 21 additions & 11 deletions fhevm-engine/coprocessor/src/tfhe_worker.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::types::CoprocessorError;
use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
use fhevm_engine_common::types::{Handle, SupportedFheCiphertexts};
use fhevm_engine_common::types::{FhevmError, Handle, SupportedFheCiphertexts};
use fhevm_engine_common::{tfhe_ops::current_ciphertext_version, types::SupportedFheOperations};
use itertools::Itertools;
use lazy_static::lazy_static;
use opentelemetry::trace::{Span, TraceContextExt, Tracer};
use opentelemetry::KeyValue;
use prometheus::{register_int_counter, IntCounter};
use scheduler::dfg::types::SchedulerError;
use scheduler::dfg::{scheduler::Scheduler, types::DFGTaskInput, DFGraph};
use sqlx::{postgres::PgListener, query, Acquire};
use std::{
Expand Down Expand Up @@ -357,29 +359,37 @@ async fn tfhe_worker_cycle(
let keys = rk.get(tenant_id).expect("Can't get tenant key from cache");

// Schedule computations in parallel as dependences allow
tfhe::set_server_key(keys.sks.clone());
let mut sched = Scheduler::new(&mut graph.graph, args.coprocessor_fhe_threads);
sched.schedule(keys.sks.clone()).await?;
}
// Extract the results from the graph
let res = graph.get_results().unwrap();
let mut res = graph.get_results();

for (idx, w) in work.iter().enumerate() {
// Filter out computations that could not complete
if uncomputable.contains_key(&idx) {
continue;
}
let r = res.iter().find(|(h, _)| *h == w.output_handle).unwrap();
{
let mut rk = tenant_key_cache.write().await;
let keys = rk
.get(&w.tenant_id)
.expect("Can't get tenant key from cache");
tfhe::set_server_key(keys.sks.clone());
}
let r = &mut res
.iter_mut()
.find(|(h, _)| *h == w.output_handle)
.unwrap()
.1;

let finished_work_unit: Result<
_,
(Box<(dyn std::error::Error + Send + Sync)>, i32, Vec<u8>),
> = Ok((w, r.1 .1, &r.1 .2));
> = r
.as_mut()
.map(|rok| (w, rok.1, std::mem::take(&mut rok.2)))
.map_err(|_rerr| {
(
CoprocessorError::SchedulerError(SchedulerError::SchedulerError).into(),
w.tenant_id,
w.output_handle.clone(),
)
});
match finished_work_unit {
Ok((w, db_type, db_bytes)) => {
let mut s = tracer.start_with_context("insert_ct_into_db", &loop_ctx);
Expand Down
19 changes: 14 additions & 5 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,22 @@ impl FhevmExecutor for FhevmExecutorService {
now.elapsed().unwrap().as_millis()
);
// Extract the results from the graph
match graph.get_results() {
Ok(mut result_cts) => Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: result_cts
let results = graph.get_results();
let outputs: Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>> =
results
.into_iter()
.map(|(h, output)| match output {
Ok(output) => Ok((h, output)),
Err(e) => Err(e),
})
.collect();
match outputs {
Ok(mut outputs) => Some(Resp::ResultCiphertexts(ResultCiphertexts {
ciphertexts: outputs
.iter_mut()
.map(|(h, ct)| CompressedCiphertext {
.map(|(h, output)| CompressedCiphertext {
handle: h.clone(),
serialization: std::mem::take(&mut ct.2),
serialization: std::mem::take(&mut output.2),
})
.collect(),
})),
Expand Down
13 changes: 8 additions & 5 deletions fhevm-engine/scheduler/src/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,19 @@ impl DFGraph {

pub fn get_results(
&mut self,
) -> Result<Vec<(Handle, (SupportedFheCiphertexts, i16, Vec<u8>))>> {
) -> Vec<(Handle, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)> {
let mut res = Vec::with_capacity(self.graph.node_count());
for index in 0..self.graph.node_count() {
let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap();
if let Some(ct) = &node.result {
res.push((node.result_handle.clone(), ct.clone()));
if let Some(ct) = std::mem::take(&mut node.result) {
res.push((node.result_handle.clone(), ct));
} else {
return Err(SchedulerError::DataflowGraphError.into());
res.push((
node.result_handle.clone(),
Err(SchedulerError::DataflowGraphError.into()),
));
}
}
Ok(res)
res
}
}
160 changes: 90 additions & 70 deletions fhevm-engine/scheduler/src/dfg/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl<'a> Scheduler<'a> {
}

async fn schedule_fine_grain(&mut self, server_key: tfhe::ServerKey) -> Result<()> {
let mut set: JoinSet<Result<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>> =
let mut set: JoinSet<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)> =
JoinSet::new();
tfhe::set_server_key(server_key.clone());
// Prime the scheduler with all nodes without dependences
Expand Down Expand Up @@ -127,36 +127,39 @@ impl<'a> Scheduler<'a> {
}
// Get results from computations and update dependences of remaining computations
while let Some(result) = set.join_next().await {
let output = result??;
let index = output.0;
let result = result?;
let index = result.0;
let node_index = NodeIndex::new(index);
// Satisfy deps from the executed task
for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
let sks = server_key.clone();
let child_index = edge.target();
let child_node = self.graph.node_weight_mut(child_index).unwrap();
child_node.inputs[*edge.weight() as usize] =
DFGTaskInput::Value(output.1 .0.clone());
if Self::is_ready(child_node) {
let opcode = child_node.opcode;
let inputs: Result<Vec<SupportedFheCiphertexts>> = child_node
.inputs
.iter()
.map(|i| match i {
DFGTaskInput::Value(i) => Ok(i.clone()),
DFGTaskInput::Compressed((t, c)) => {
SupportedFheCiphertexts::decompress(*t, c)
}
_ => Err(SchedulerError::UnsatisfiedDependence.into()),
})
.collect();
set.spawn_blocking(move || {
tfhe::set_server_key(sks.clone());
run_computation(opcode, inputs, child_index.index())
});
if let Ok(output) = &result.1 {
// Satisfy deps from the executed task
for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
let sks = server_key.clone();
let child_index = edge.target();
let child_node = self.graph.node_weight_mut(child_index).unwrap();
child_node.inputs[*edge.weight() as usize] =
DFGTaskInput::Value(output.0.clone());
if Self::is_ready(child_node) {
let opcode = child_node.opcode;
let inputs: Result<Vec<SupportedFheCiphertexts>> = child_node
.inputs
.iter()
.map(|i| match i {
DFGTaskInput::Value(i) => Ok(i.clone()),
DFGTaskInput::Compressed((t, c)) => {
SupportedFheCiphertexts::decompress(*t, c)
}
_ => Err(SchedulerError::UnsatisfiedDependence.into()),
})
.collect();
set.spawn_blocking(move || {
tfhe::set_server_key(sks.clone());
run_computation(opcode, inputs, child_index.index())
});
}
}
}
self.graph[node_index].result = Some(output.1);
let node_index = NodeIndex::new(result.0);
self.graph[node_index].result = Some(result.1);
}
Ok(())
}
Expand All @@ -167,12 +170,10 @@ impl<'a> Scheduler<'a> {
server_key: tfhe::ServerKey,
) -> Result<()> {
tfhe::set_server_key(server_key.clone());
let mut set: JoinSet<
Result<(
Vec<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>,
NodeIndex,
)>,
> = JoinSet::new();
let mut set: JoinSet<(
Vec<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)>,
NodeIndex,
)> = JoinSet::new();
let mut execution_graph: Dag<ExecNode, ()> = Dag::default();
let _ = match strategy {
PartitionStrategy::MaxLocality => {
Expand Down Expand Up @@ -205,18 +206,25 @@ impl<'a> Scheduler<'a> {
}
// Get results from computations and update dependences of remaining computations
while let Some(result) = set.join_next().await {
let mut output = result??;
let task_index = output.1;
while let Some(o) = output.0.pop() {
let mut result = result?;
let task_index = result.1;
while let Some(o) = result.0.pop() {
let index = o.0;
let node_index = NodeIndex::new(index);
// Satisfy deps from the executed computation in the DFG
for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
let child_index = edge.target();
let child_node = self.graph.node_weight_mut(child_index).unwrap();
if !child_node.inputs.is_empty() {
child_node.inputs[*edge.weight() as usize] =
DFGTaskInput::Value(o.1 .0.clone());
// If this node's result is an error, we can't satisfy
// any dependences with it, so skip it - all dependences
// on it will remain unsatisfied and result in
// further errors.
if o.1.is_ok() {
// Satisfy deps from the executed computation in the DFG
for edge in self.edges.edges_directed(node_index, Direction::Outgoing) {
let child_index = edge.target();
let child_node = self.graph.node_weight_mut(child_index).unwrap();
if !child_node.inputs.is_empty() {
// The unwrap below cannot fail: o.1.is_ok() was checked above
child_node.inputs[*edge.weight() as usize] =
DFGTaskInput::Value(o.1.as_ref().unwrap().0.clone());
}
}
}
self.graph[node_index].result = Some(o.1);
Expand Down Expand Up @@ -268,7 +276,6 @@ impl<'a> Scheduler<'a> {
}
}


let (src, dest) = channel();
let rayon_threads = self.rayon_threads;

Expand All @@ -288,12 +295,12 @@ impl<'a> Scheduler<'a> {
))
.unwrap();
});
}).await?;

})
.await?;

let results: Vec<_> = dest.iter().collect();
for result in results {
let mut output = result?;
while let Some(o) = output.0.pop() {
for mut result in results {
while let Some(o) = result.0.pop() {
let index = o.0;
let node_index = NodeIndex::new(index);
self.graph[node_index].result = Some(o.1);
Expand Down Expand Up @@ -424,35 +431,45 @@ fn execute_partition(
use_global_threadpool: bool,
rayon_threads: usize,
server_key: tfhe::ServerKey,
) -> Result<(
Vec<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))>,
) -> (
Vec<(usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>)>,
NodeIndex,
)> {
let mut res: HashMap<usize, (SupportedFheCiphertexts, i16, Vec<u8>)> =
) {
let mut res: HashMap<usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>> =
HashMap::with_capacity(computations.len());
for (opcode, inputs, nidx) in computations {
'comps: for (opcode, inputs, nidx) in computations {
let mut cts = Vec::with_capacity(inputs.len());
for i in inputs.iter() {
match i {
DFGTaskInput::Dependence(d) => {
if let Some(d) = d {
if let Some(ct) = res.get(d) {
if let Some(Ok(ct)) = res.get(d) {
cts.push(ct.0.clone());
} else {
res.insert(
nidx.index(),
Err(SchedulerError::UnsatisfiedDependence.into()),
);
continue 'comps;
}
} else {
return Err(SchedulerError::UnsatisfiedDependence.into());
}
}
DFGTaskInput::Value(v) => {
cts.push(v.clone());
}
DFGTaskInput::Compressed((t, c)) => {
cts.push(SupportedFheCiphertexts::decompress(*t, c)?);
let decomp = SupportedFheCiphertexts::decompress(*t, c);
if let Ok(decomp) = decomp {
cts.push(decomp);
} else {
res.insert(nidx.index(), Err(decomp.err().unwrap()));
continue 'comps;
}
}
}
}
if use_global_threadpool {
let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?;
let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index());
res.insert(node_index, result);
} else {
let thread_pool = THREAD_POOL
Expand All @@ -470,38 +487,41 @@ fn execute_partition(
thread_pool.broadcast(|_| {
tfhe::set_server_key(server_key.clone());
});
thread_pool.install(|| -> Result<()> {
let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index())?;
let _ = thread_pool.install(|| -> Result<()> {
let (node_index, result) = run_computation(opcode, Ok(cts), nidx.index());
res.insert(node_index, result);
Ok(())
})?;
});
THREAD_POOL.set(Some(thread_pool));
}
}
Ok((Vec::from_iter(res), task_id))
(Vec::from_iter(res), task_id)
}

fn run_computation(
operation: i32,
inputs: Result<Vec<SupportedFheCiphertexts>>,
graph_node_index: usize,
) -> Result<(usize, (SupportedFheCiphertexts, i16, Vec<u8>))> {
) -> (usize, Result<(SupportedFheCiphertexts, i16, Vec<u8>)>) {
let op = FheOperation::try_from(operation);
match inputs {
Ok(inputs) => match op {
Ok(FheOperation::FheGetCiphertext) => {
let (ct_type, ct_bytes) = inputs[0].compress();
Ok((graph_node_index, (inputs[0].clone(), ct_type, ct_bytes)))
(graph_node_index, Ok((inputs[0].clone(), ct_type, ct_bytes)))
}
Ok(_) => match perform_fhe_operation(operation as i16, &inputs) {
Ok(result) => {
let (ct_type, ct_bytes) = result.compress();
Ok((graph_node_index, (result.clone(), ct_type, ct_bytes)))
(graph_node_index, Ok((result.clone(), ct_type, ct_bytes)))
}
Err(e) => Err(e.into()),
Err(e) => (graph_node_index, Err(e.into())),
},
_ => Err(SchedulerError::UnknownOperation(operation).into()),
_ => (
graph_node_index,
Err(SchedulerError::UnknownOperation(operation).into()),
),
},
Err(_) => Err(SchedulerError::InvalidInputs.into()),
Err(_) => (graph_node_index, Err(SchedulerError::InvalidInputs.into())),
}
}
3 changes: 2 additions & 1 deletion fhevm-engine/scheduler/src/dfg/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::Result;
use fhevm_engine_common::types::SupportedFheCiphertexts;

pub type DFGTaskResult = Option<(SupportedFheCiphertexts, i16, Vec<u8>)>;
pub type DFGTaskResult = Option<Result<(SupportedFheCiphertexts, i16, Vec<u8>)>>;

#[derive(Clone)]
pub enum DFGTaskInput {
Expand Down

0 comments on commit d1311be

Please sign in to comment.