Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: remove computations.output_type #291

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions fhevm-engine/coprocessor/src/db_queries.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::{BTreeSet, HashMap};
use std::str::FromStr;
use std::sync::Arc;

Expand Down Expand Up @@ -58,52 +57,6 @@ pub async fn check_if_api_key_is_valid<T>(
}
}

/// Returns ciphertext types
pub async fn check_if_ciphertexts_exist_in_db(
mut cts: BTreeSet<Vec<u8>>,
tenant_id: i32,
pool: &sqlx::Pool<Postgres>,
) -> Result<HashMap<Vec<u8>, i16>, CoprocessorError> {
let handles_to_check_in_db_vec = cts.iter().cloned().collect::<Vec<_>>();
let ciphertexts = query!(
r#"
-- existing computations
SELECT handle AS "handle!", ciphertext_type AS "ciphertext_type!"
FROM ciphertexts
WHERE tenant_id = $2
AND handle = ANY($1::BYTEA[])
UNION
-- pending computations
SELECT output_handle AS "handle!", output_type AS "ciphertext_type!"
FROM computations
WHERE tenant_id = $2
AND output_handle = ANY($1::BYTEA[])
"#,
&handles_to_check_in_db_vec,
tenant_id,
)
.fetch_all(pool)
.await
.map_err(Into::<CoprocessorError>::into)?;

let mut result = HashMap::with_capacity(cts.len());
for ct in ciphertexts {
assert!(cts.remove(&ct.handle), "any ciphertext selected must exist");
assert!(result
.insert(ct.handle.clone(), ct.ciphertext_type)
.is_none());
}

if !cts.is_empty() {
return Err(CoprocessorError::UnexistingInputCiphertextsFound(
cts.into_iter()
.map(|i| format!("0x{}", hex::encode(i)))
.collect(),
));
}

Ok(result)
}

pub struct FetchTenantKeyResult {
pub chain_id: i32,
Expand Down
32 changes: 6 additions & 26 deletions fhevm-engine/coprocessor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::num::NonZeroUsize;
use std::str::FromStr;

use crate::db_queries::{
check_if_api_key_is_valid, check_if_ciphertexts_exist_in_db, fetch_tenant_server_key,
check_if_api_key_is_valid, fetch_tenant_server_key,
};
use crate::server::coprocessor::GenericResponse;
use crate::types::{CoprocessorError, TfheTenantKeys};
Expand Down Expand Up @@ -626,22 +626,17 @@ impl CoprocessorService {
let mut span = tracer.child_span("sort_computations_by_dependencies");
// computations are now sorted based on dependencies or error should have
// been returned if there's circular dependency
let (sorted_computations, handles_to_check_in_db) =
let (sorted_computations, _handles_to_check_in_db) =
sort_computations_by_dependencies(&req.computations)?;
span.end();

// to insert to db
let mut span = tracer.child_span("check_if_ciphertexts_exist_in_db");
let mut ct_types =
check_if_ciphertexts_exist_in_db(handles_to_check_in_db, tenant_id, &self.pool).await?;
span.end();
let mut computations_inputs: Vec<Vec<Vec<u8>>> =
Vec::with_capacity(sorted_computations.len());
let mut computations_outputs: Vec<Vec<u8>> = Vec::with_capacity(sorted_computations.len());
let mut are_comps_scalar: Vec<bool> = Vec::with_capacity(sorted_computations.len());
for comp in &sorted_computations {
computations_outputs.push(comp.output_handle.clone());
let mut handle_types = Vec::with_capacity(comp.inputs.len());
let mut is_computation_scalar = false;
let mut this_comp_inputs: Vec<Vec<u8>> = Vec::with_capacity(comp.inputs.len());
let mut is_scalar_op_vec: Vec<bool> = Vec::with_capacity(comp.inputs.len());
Expand All @@ -653,16 +648,11 @@ impl CoprocessorService {
if let Some(input) = &ih.input {
match input {
Input::InputHandle(ih) => {
let ct_type = ct_types
.get(ih)
.expect("this must be found if operand is non scalar");
handle_types.push(*ct_type);
this_comp_inputs.push(ih.clone());
is_scalar_op_vec.push(false);
}
Input::Scalar(sc) => {
is_computation_scalar = true;
handle_types.push(-1);
this_comp_inputs.push(sc.clone());
is_scalar_op_vec.push(true);
assert!(idx == 1 || fhe_op.does_have_more_than_one_scalar(), "we should have checked earlier that only second operand can be scalar");
Expand All @@ -673,20 +663,15 @@ impl CoprocessorService {

// check before we insert computation that it has
// to succeed according to the type system
let output_type = check_fhe_operand_types(
check_fhe_operand_types(
comp.operation,
&handle_types,
&this_comp_inputs,
&is_scalar_op_vec,
)
.map_err(|e| CoprocessorError::FhevmError(e))?;

computations_inputs.push(this_comp_inputs);
are_comps_scalar.push(is_computation_scalar);
// fill in types with output handles that are computed as we go
assert!(ct_types
.insert(comp.output_handle.clone(), output_type)
.is_none());
}

let mut tx_span = tracer.child_span("db_transaction");
Expand All @@ -698,9 +683,6 @@ impl CoprocessorService {

let mut new_work_available = false;
for (idx, comp) in sorted_computations.iter().enumerate() {
let output_type = ct_types
.get(&comp.output_handle)
.expect("we should have collected all output result types by now with check_fhe_operand_types");
let fhe_operation: i16 = comp.operation.try_into().map_err(|_| {
CoprocessorError::FhevmError(FhevmError::UnknownFheOperation(comp.operation))
})?;
Expand All @@ -717,18 +699,16 @@ impl CoprocessorService {
dependencies,
fhe_operation,
is_completed,
is_scalar,
output_type
is_scalar
)
VALUES($1, $2, $3, $4, false, $5, $6)
VALUES($1, $2, $3, $4, false, $5)
ON CONFLICT (tenant_id, output_handle) DO NOTHING
",
tenant_id,
comp.output_handle,
&computations_inputs[idx],
fhe_operation,
are_comps_scalar[idx],
output_type
are_comps_scalar[idx]
)
.execute(trx.as_mut())
.await
Expand Down
33 changes: 6 additions & 27 deletions fhevm-engine/coprocessor/src/tests/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,8 @@ async fn test_coprocessor_computation_errors() -> Result<(), Box<dyn std::error:
MetadataValue::from_str(&api_key_header).unwrap(),
);
match client.async_compute(input_request).await {
Ok(_) => {
panic!("Expected failure")
}
Err(e) => {
eprintln!("error: {}", e);
assert!(e
.to_string()
.contains("fhevm error: FheOperationDoesntHaveUniformTypesAsInput"));
}
Ok(_) => (),
Err(_e) => panic!("No type error detections."),
}
}

Expand Down Expand Up @@ -623,15 +616,8 @@ async fn test_coprocessor_computation_errors() -> Result<(), Box<dyn std::error:
MetadataValue::from_str(&api_key_header).unwrap(),
);
match client.async_compute(input_request).await {
Ok(_) => {
panic!("Expected failure")
}
Err(e) => {
eprintln!("error: {}", e);
assert!(e
.to_string()
.contains("fhevm error: OperationDoesntSupportBooleanInputs"));
}
Ok(_) => (),
Err(_e) => panic!("No type error detections."),
}
}

Expand All @@ -653,15 +639,8 @@ async fn test_coprocessor_computation_errors() -> Result<(), Box<dyn std::error:
MetadataValue::from_str(&api_key_header).unwrap(),
);
match client.async_compute(input_request).await {
Ok(_) => {
panic!("Expected failure")
}
Err(e) => {
eprintln!("error: {}", e);
assert!(e
.to_string()
.contains("fhevm error: OperationDoesntSupportBooleanInputs"));
}
Ok(_) => (),
Err(_e) => panic!("No type error detections."),
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE computations
DROP COLUMN IF EXISTS output_type;
Loading
Loading