Skip to content

Commit

Permalink
Updated server/client direct interaction to leverage my classes objec…
Browse files Browse the repository at this point in the history
…t streams. I need to get something similar going with level-sites
  • Loading branch information
AndrewQuijano committed Jan 14, 2024
1 parent 291b269 commit 43b147d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 82 deletions.
65 changes: 27 additions & 38 deletions src/main/java/weka/finito/client.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Hashtable;

import java.lang.System;

import security.dgk.DGKKeyPairGenerator;
Expand Down Expand Up @@ -43,8 +41,7 @@ public final class client implements Runnable {

private KeyPair dgk;
private KeyPair paillier;
private Hashtable<String, BigIntegers> feature = null;

private HashMap<String, BigIntegers> feature = null;
private boolean classification_complete = false;
private String [] classes;

Expand Down Expand Up @@ -249,7 +246,7 @@ private void setup_with_server_site(PaillierPublicKey paillier, DGKPublicKey dgk
}

// Evaluation
private Hashtable<String, BigIntegers> read_features(String path,
private HashMap<String, BigIntegers> read_features(String path,
PaillierPublicKey paillier_public_key,
DGKPublicKey dgk_public_key,
int precision)
Expand All @@ -258,7 +255,7 @@ private Hashtable<String, BigIntegers> read_features(String path,
BigInteger integerValuePaillier;
BigInteger integerValueDGK;
int intermediateInteger;
Hashtable<String, BigIntegers> values = new Hashtable<>();
HashMap<String, BigIntegers> values = new HashMap<>();
try (BufferedReader br = new BufferedReader(new FileReader(path))) {
String line;

Expand Down Expand Up @@ -287,27 +284,25 @@ private Hashtable<String, BigIntegers> read_features(String path,
}
}

private void evaluate_with_server_site(Socket server_site) throws IOException, HomomorphicException, ClassNotFoundException {
private void evaluate_with_server_site(Socket server_site)
throws IOException, HomomorphicException, ClassNotFoundException {
// Communicate with each Level-Site
Object o;
bob_joye client;

// Create I/O streams
ObjectOutputStream to_server_site = new ObjectOutputStream(server_site.getOutputStream());
ObjectInputStream from_server_site = new ObjectInputStream(server_site.getInputStream());

// Send the encrypted data to Level-Site
to_server_site.writeObject(this.feature);
to_server_site.flush();

// Send the Public Keys using Alice and Bob
client = new bob_joye(paillier, dgk, null);
client.set_socket(server_site);

// Work with the comparison
// Send the encrypted data to Level-Site
ObjectOutputStream oos = new ObjectOutputStream(server_site.getOutputStream());
oos.writeObject(this.feature);
oos.flush();

// Yup, I need the while loop here because all level-sites are at server
int comparison_type;
while(true) {
comparison_type = from_server_site.readInt();
while (true) {
comparison_type = client.readInt();
if (comparison_type == -1) {
this.classification_complete = true;
break;
Expand All @@ -321,7 +316,7 @@ else if (comparison_type == 1) {
client.Protocol2();
}

o = from_server_site.readObject();
o = client.readObject();
if (o instanceof String) {
classification = (String) o;
classification = hashed_classification.get(classification);
Expand Down Expand Up @@ -353,26 +348,20 @@ private void communicate_with_level_site(Socket level_site)
to_level_site.writeInt(next_index);
to_level_site.flush();

// Work with the comparison
int comparison_type;
while(true) {
comparison_type = from_level_site.readInt();
if (comparison_type == -2) {
System.out.println("LEVEL-SITE DOESN'T HAVE DATA!!!");
this.classification_complete = true;
return;
}
else if (comparison_type == -1) {
break;
}
else if (comparison_type == 0) {
client.setDGKMode(false);
}
else if (comparison_type == 1) {
client.setDGKMode(true);
}
client.Protocol2();
// Get the comparison
int comparison_type = from_level_site.readInt();
if (comparison_type == -1) {
System.out.println("LEVEL-SITE DOESN'T HAVE DATA!!!");
this.classification_complete = true;
return;
}
else if (comparison_type == 0) {
client.setDGKMode(false);
}
else if (comparison_type == 1) {
client.setDGKMode(true);
}
client.Protocol2();

// Get boolean from level-site:
// true - get leaf value
Expand Down
23 changes: 7 additions & 16 deletions src/main/java/weka/finito/level_site_evaluation_thread.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.Socket;
import java.util.Hashtable;
import java.util.HashMap;
import java.util.Map;

import static weka.finito.utils.shared.*;
Expand All @@ -20,19 +20,19 @@ public class level_site_evaluation_thread implements Runnable {

private final Socket client_socket;
private final level_order_site level_site_data;
private final Hashtable<String, BigIntegers> encrypted_features = new Hashtable<>();
private final HashMap<String, BigIntegers> encrypted_features = new HashMap<>();

// This thread is ONLY to handle evaluations
public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data, Hashtable x) {
public level_site_evaluation_thread(Socket client_socket, level_order_site level_site_data, HashMap x) {
// Have encrypted copy of thresholds if not done already for all nodes in level-site
this.level_site_data = level_site_data;
this.client_socket = client_socket;

for (Map.Entry<?, ?> entry: ((Hashtable<?, ?>) x).entrySet()) {
for (Map.Entry<?, ?> entry: ((HashMap<?, ?>) x).entrySet()) {
if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) {
encrypted_features.put((String) entry.getKey(), (BigIntegers) entry.getValue());
}
}
}
}

// This will run the communication with client and next level site
Expand All @@ -45,23 +45,14 @@ public final void run() {
ois = new ObjectInputStream(client_socket.getInputStream());
oos = new ObjectOutputStream(client_socket.getOutputStream());
niu.set_socket(client_socket);
if (this.level_site_data == null) {
oos.writeInt(-2);
closeConnection(oos, ois, client_socket);
return;
}

niu.setDGKPublicKey(this.level_site_data.dgk_public_key);
niu.setPaillierPublicKey(this.level_site_data.paillier_public_key);

level_site_data.set_current_index(ois.readInt());

// Null, keep going down the tree,
// Not null, you got the correct leaf node of your DT!
NodeInfo reply = traverse_level(level_site_data, encrypted_features, oos, niu);

// Place -1 to break Protocol4 loop
oos.writeInt(-1);
oos.flush();
NodeInfo reply = traverse_level(level_site_data, encrypted_features, niu);

if (reply != null) {
// Tell the client the value
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/weka/finito/level_site_server.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.io.ObjectOutputStream;

import java.lang.System;
import java.util.Hashtable;
import java.util.HashMap;

import static weka.finito.utils.shared.*;

Expand Down Expand Up @@ -88,9 +88,9 @@ public void run() {
oos.writeBoolean(true);
closeConnection(oos, ois, client_socket);
}
else if (o instanceof Hashtable) {
else if (o instanceof HashMap) {
// Start evaluating with the client
Hashtable x = (Hashtable) o;
HashMap x = (HashMap) o;
level_site_evaluation_thread current_level_site_class = new level_site_evaluation_thread(client_socket,
this.level_site_parameters, x);
new Thread(current_level_site_class).start();
Expand Down
28 changes: 13 additions & 15 deletions src/main/java/weka/finito/server.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,24 @@ private void run_server_site(int port) throws IOException, HomomorphicException,
private void evaluate_with_client_directly(SSLSocket client_site)
throws IOException, HomomorphicException, ClassNotFoundException {

ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream());
ObjectInputStream from_client_site = new ObjectInputStream(client_site.getInputStream());

Object client_input;
Hashtable<String, BigIntegers> features = new Hashtable<>();
HashMap<String, BigIntegers> features = new HashMap<>();

alice_joye Niu = new alice_joye();
Niu.set_socket(client_site);
Niu.setPaillierPublicKey(paillier_public);
Niu.setDGKPublicKey(dgk_public);

// Get encrypted features
client_input = from_client_site.readObject();
if (client_input instanceof Hashtable) {
for (Entry<?, ?> entry: ((Hashtable<?, ?>) client_input).entrySet()){
ObjectInputStream ois = new ObjectInputStream(client_site.getInputStream());
client_input = ois.readObject();
if (client_input instanceof HashMap) {
for (Entry<?, ?> entry: ((HashMap<?, ?>) client_input).entrySet()){
if (entry.getKey() instanceof String && entry.getValue() instanceof BigIntegers) {
features.put((String) entry.getKey(), (BigIntegers) entry.getValue());
}
}
}
alice_joye Niu = new alice_joye();
Niu.set_socket(client_site);
Niu.setPaillierPublicKey(paillier_public);
Niu.setDGKPublicKey(dgk_public);

long start_time = System.nanoTime();
int previous_index = 0;
Expand All @@ -177,14 +176,13 @@ private void evaluate_with_client_directly(SSLSocket client_site)
level_site_data.set_current_index(previous_index);

// Handle at a level...
NodeInfo leaf = traverse_level(level_site_data, features, to_client_site, Niu);
NodeInfo leaf = traverse_level(level_site_data, features, Niu);

// You found a leaf! No more traversing needed!
if (leaf != null) {
// Tell the client the value
to_client_site.writeInt(-1);
to_client_site.writeObject(leaf.getVariableName());
to_client_site.flush();
Niu.writeInt(-1);
Niu.writeObject(leaf.getVariableName());
long stop_time = System.nanoTime();
double run_time = (double) (stop_time - start_time);
run_time = run_time / 1000000;
Expand Down
18 changes: 8 additions & 10 deletions src/main/java/weka/finito/utils/shared.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Hashtable;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;

Expand Down Expand Up @@ -69,8 +69,8 @@ public static void setup_tls() {
}

public static NodeInfo traverse_level(level_order_site level_site_data,
Hashtable<String, BigIntegers> encrypted_features,
ObjectOutputStream toClient, alice niu)
HashMap<String, BigIntegers> encrypted_features,
alice niu)
throws HomomorphicException, IOException, ClassNotFoundException {

List<NodeInfo> node_level_data = level_site_data.get_node_data();
Expand Down Expand Up @@ -100,12 +100,12 @@ public static NodeInfo traverse_level(level_order_site level_site_data,

if (ls.comparisonType == 6) {
inequalityHolds = compare(ls, 1,
encrypted_features, toClient, niu);
encrypted_features, niu);
inequalityHolds = !inequalityHolds;
}
else {
inequalityHolds = compare(ls, ls.comparisonType,
encrypted_features, toClient, niu);
encrypted_features, niu);
}

if (inequalityHolds) {
Expand All @@ -124,8 +124,7 @@ public static NodeInfo traverse_level(level_order_site level_site_data,

// Used by level-site and server-site to compare with a client
public static boolean compare(NodeInfo ld, int comparisonType,
Hashtable<String, BigIntegers> encrypted_features,
ObjectOutputStream toClient, alice Niu)
HashMap<String, BigIntegers> encrypted_features, alice Niu)
throws ClassNotFoundException, HomomorphicException, IOException {

long start_time = System.nanoTime();
Expand All @@ -138,16 +137,15 @@ public static boolean compare(NodeInfo ld, int comparisonType,
if ((comparisonType == 1) || (comparisonType == 2) || (comparisonType == 4)) {
encrypted_thresh = ld.getPaillier();
encrypted_client_value = encrypted_values.getIntegerValuePaillier();
toClient.writeInt(0);
Niu.writeInt(0);
Niu.setDGKMode(false);
}
else if ((comparisonType == 3) || (comparisonType == 5)) {
encrypted_thresh = ld.getDGK();
encrypted_client_value = encrypted_values.getIntegerValueDGK();
toClient.writeInt(1);
Niu.writeInt(1);
Niu.setDGKMode(true);
}
toClient.flush();
assert encrypted_client_value != null;
long stop_time = System.nanoTime();

Expand Down

0 comments on commit 43b147d

Please sign in to comment.