Skip to content

Commit

Permalink
Optimize checks for malicious
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Mar 8, 2023
1 parent 7e7e0a0 commit 0aeea65
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 33 deletions.
28 changes: 14 additions & 14 deletions src/bin/hh_leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,16 @@ async fn add_keys(

let mut addkeys_0 = vec![Vec::with_capacity(nreqs); 3];
let mut addkeys_1 = vec![Vec::with_capacity(nreqs); 3];
for _ in 0..nreqs {
let idx = zipf.sample(&mut rng) - 1;
for r in 0..nreqs {
let idx_1 = zipf.sample(&mut rng) - 1;
let mut idx_2 = idx_1;
if rand::thread_rng().gen_range(0.0..1.0) < malicious_percentage {
idx_2 += 1;
println!("Malicious {}", r);
}
for i in 0..3 {
addkeys_0[i].push(keys[i].0[idx].clone());
if rand::thread_rng().gen_range(0.0..1.0) < malicious_percentage {
// if r == 0 {
println!("Malicious");
addkeys_1[i].push(keys[i].1[(idx+1) % cfg.unique_buckets].clone());
} else {
addkeys_1[i].push(keys[i].1[idx].clone());
}
addkeys_0[i].push(keys[i].0[idx_1].clone());
addkeys_1[i].push(keys[i].1[idx_2 % cfg.unique_buckets].clone());
}
}

Expand Down Expand Up @@ -358,7 +357,8 @@ async fn run_level(

join_all(responses).await;

let ((vals0, root0), (vals1, root1)) = (response_00.await?.unwrap(), response_01.await?.unwrap());
let ((vals0, root0, indices0), (vals1, root1, _)) =
(response_00.await?.unwrap(), response_01.await?.unwrap());
debug_assert_eq!(vals0.len(), vals1.len());
keep = collect::KeyCollection::<fastfield::FE,FieldElm>::keep_values_cmp(&threshold, &vals0, &vals1);

Expand All @@ -372,14 +372,14 @@ async fn run_level(
let hl0 = &left_root0[i].iter().map(|x| format!("{:02x}", x)).collect::<String>();
let hl1 = &left_root1[i].iter().map(|x| format!("{:02x}", x)).collect::<String>();
if hl0 != hl1 {
malicious.push(i);
// println!("{}) left different {} vs {}", i, hl0, hl1);
malicious.push(indices0[0][i]);
// println!("{}) different {} vs {}", i, hl0, hl1);
}
}
if malicious.len() == 0 {
break;
} else {
println!("Detected malicious {:?} out of {} clients", malicious, nreqs);
// println!("Detected malicious {:?} out of {} clients", malicious, nreqs);
if split > nreqs {
if !is_last {
is_last = true;
Expand Down
2 changes: 1 addition & 1 deletion src/bin/hh_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl Collector for BatchCollectorServer {

async fn tree_crawl(self,
_: context::Context, req: HHTreeCrawlRequest
) -> (Vec<FE>, Vec<Vec<Vec<u8>>>) {
) -> (Vec<FE>, Vec<Vec<Vec<u8>>>, Vec<Vec<usize>>) {
// let start = Instant::now();
let client_idx = req.client_idx as usize;
let split_by = req.split_by;
Expand Down
69 changes: 52 additions & 17 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct TreeNode<T> {
key_states: Vec<dpf::EvalState>,
key_values: Vec<T>,
hashes: Vec<Vec<u8>>,
indices: Vec<usize>
}

unsafe impl<T> Send for TreeNode<T> {}
Expand Down Expand Up @@ -116,6 +117,7 @@ where
key_states: vec![],
key_values: vec![],
hashes: vec![],
indices: vec![],
};

for k in &self.keys {
Expand All @@ -130,7 +132,12 @@ where
}

fn make_tree_node(
&self, parent: &TreeNode<T>, dir: bool, split_by: usize
&self,
parent: &TreeNode<T>,
dir: bool,
split_by: usize,
malicious_indices: &Vec<usize>,
is_last: bool,
) -> TreeNode<T> {
let (key_states, key_values): (Vec<dpf::EvalState>, Vec<T>) = self
.keys
Expand Down Expand Up @@ -178,17 +185,41 @@ where
}
})
.collect::<Vec<_>>();
// let mut root = [0u8; 32];
let mut roots = Vec::new();
if bit_str.len() < 2 {
let chunk_sz = (hashes.len() as f32 / split_by as f32).ceil() as usize;
// println!("split_by: {}", chunk_sz);
let mut root_indices = Vec::new();
if bit_str.len() < 2 && !is_last {
let tree_size = 1 << (hashes.len() as f32).log2().ceil() as usize;
let chunk_sz = tree_size/ split_by;
let chunks_list: Vec<&[[u8; 32]]> = hashes.chunks(chunk_sz).collect();
// println!("chunks_list len: {}, each of which are {}", chunks_list.len(), chunks_list[0].len());
for i in 0..chunks_list.len() {
let mt = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[i]);
// println!("hashes.len(): {}\nchunk_sz: {}\ntree size: {}\nsplit_by: {}\n{} chunks, each of which has {} elements\nmalicious {:?}\n",
// hashes.len(),
// chunk_sz,
// tree_size,
// split_by,
// chunks_list.len(),
// chunks_list[0].len(),
// malicious_indices
// );
if split_by == 1 {
let mt = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[0]);
let root = mt.root().unwrap();
roots.push(root.to_vec());
root_indices.push(0);
} else {
for &i in malicious_indices {
let mt_left = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[i * 2]);
let root_left = mt_left.root().unwrap();
roots.push(root_left.to_vec());
root_indices.push(i * 2);

if i * 2 + 1 >= chunks_list.len() {
continue;
}
let mt_right = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[i * 2 + 1]);
let root_right = mt_right.root().unwrap();
roots.push(root_right.to_vec());
root_indices.push(i * 2 + 1);
}
}
}

Expand All @@ -198,6 +229,7 @@ where
key_states,
key_values,
hashes: roots,
indices: root_indices,
};

child.path.push(dir);
Expand Down Expand Up @@ -274,12 +306,13 @@ where
key_states: vec![],
key_values: vec![],
hashes: vec![],
indices: vec![],
}
}

pub fn hh_tree_crawl(
&mut self, session_index: usize, split_by: usize, malicious: &Vec<usize>, is_last: bool
) -> (Vec<T>, Vec<Vec<Vec<u8>>>) {
) -> (Vec<T>, Vec<Vec<Vec<u8>>>, Vec<Vec<usize>>) {
if malicious.len() > 0 {
if is_last {
for &malicious_client in malicious {
Expand All @@ -298,8 +331,8 @@ where
.par_iter()
.map(|node| {
assert!(node.path.len() <= self.depth);
let child0 = self.make_tree_node(node, false, split_by);
let child1 = self.make_tree_node(node, true, split_by);
let child0 = self.make_tree_node(node, false, split_by, malicious, is_last);
let child1 = self.make_tree_node(node, true, split_by, malicious, is_last);

vec![child0, child1]
})
Expand All @@ -311,15 +344,15 @@ where
.map(|node| node.value.clone())
.collect::<Vec<T>>();

let hashes = next_frontier
let (hashes, indices) = next_frontier
.par_iter()
.map(|node| node.hashes.clone())
.collect::<Vec<_>>();
.map(|node| (node.hashes.clone(), node.indices.clone()))
.collect::<(Vec<_>, Vec<_>)>();

self.prev_frontier = self.frontier.clone();
self.frontier = next_frontier;

(values, hashes)
(values, hashes, indices)
}

pub fn histogram_tree_crawl(&mut self) {
Expand All @@ -328,8 +361,8 @@ where
.par_iter()
.map(|node| {
// assert!(node.path.len() <= self.depth);
let child0 = self.make_tree_node(node, false, 1);
let child1 = self.make_tree_node(node, true, 1);
let child0 = self.make_tree_node(node, false, 1, &vec![], false);
let child1 = self.make_tree_node(node, true, 1, &vec![], false);

vec![child0, child1]
})
Expand All @@ -354,13 +387,15 @@ where
key_states: vec![],
key_values: key_values_l,
hashes: hashes_l,
indices: vec![],
},
TreeNode::<U> {
path: path_r,
value: U::zero(),
key_states: vec![],
key_values: key_values_r,
hashes: hashes_r,
indices: vec![],
})
})
.collect::<Vec<(TreeNode<U>, TreeNode<U>)>>();
Expand Down
2 changes: 1 addition & 1 deletion src/hh_rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub trait Collector {
async fn reset(rst: HHResetRequest) -> String;
async fn add_keys(add: HHAddKeysRequest) -> String;
async fn tree_init(req: HHTreeInitRequest) -> String;
async fn tree_crawl(req: HHTreeCrawlRequest) -> (Vec<FE>, Vec<Vec<Vec<u8>>>);
async fn tree_crawl(req: HHTreeCrawlRequest) -> (Vec<FE>, Vec<Vec<Vec<u8>>>, Vec<Vec<usize>>);
async fn tree_crawl_last(req: HHTreeCrawlLastRequest) -> (Vec<Vec<u8>>, Vec<FieldElm>);
async fn tree_prune(req: HHTreePruneRequest) -> String;
async fn tree_prune_last(req: HHTreePruneLastRequest) -> String;
Expand Down

0 comments on commit 0aeea65

Please sign in to comment.