Skip to content

Commit

Permalink
Merge pull request #5 from zhewenshen/david-section
Browse files Browse the repository at this point in the history
current progress
  • Loading branch information
aacaqq authored Jul 20, 2024
2 parents f33dedd + 0430cd8 commit 303e001
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
13 changes: 11 additions & 2 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tenseal as ts
from cryptography.fernet import Fernet
from context_gen import TenSEALContext
from rich.console import Console
from rich.progress import track
Expand All @@ -15,6 +16,11 @@ def __init__(self):
self.model_inference_context = TenSEALContext.create_machine_learning_context()
self.test_data = {}
self.console = Console()
self.key = b""

def set_key(self, key: bytes):
self.key = key


def encrypt_data(self, data: List[float], context: ts.Context) -> str:
encrypted_data = ts.ckks_vector(context, data)
Expand Down Expand Up @@ -53,9 +59,12 @@ def send_request(self, server: 'Server', request: Dict[str, Any]) -> Dict[str, A
'x': [self.encrypt_data(x, context) for x in request['inference_data']['x']]
}

cur_key = Fernet(self.key)
serialized_request = serialize(request)
response = server.handle_request(serialized_request)
response_dict = deserialize(response)
cipher_text = cur_key.encrypt(serialized_request.encode('utf-8'))
# cipher_text = b"abcd" + cipher_text
response = server.handle_request(cipher_text)
response_dict = deserialize(cur_key.decrypt(response))

if 'result' in response_dict:
if isinstance(response_dict['result'], str):
Expand Down
11 changes: 11 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from sklearn.preprocessing import StandardScaler
from client import Client
from server import Server
from cryptography.fernet import Fernet
from utils import serialize, deserialize


def test_statistical_computations(client, server):
Expand Down Expand Up @@ -119,6 +121,15 @@ def test_machine_learning(client, server):
if __name__ == "__main__":
client = Client()
server = Server()
key = Fernet.generate_key();
client.set_key(key)
server.set_key(key)

# cur_key = Fernet(key)
# serialized_request = serialize("beep")
# cipher_text = cur_key.encrypt(serialized_request.encode('utf-8'))
# request = cur_key.decrypt(cipher_text)
# print(request)

print("Testing Statistical Computations:")
test_statistical_computations(client, server)
Expand Down
33 changes: 20 additions & 13 deletions src/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List
from utils import serialize, deserialize
from model import EncryptedLR
from cryptography.fernet import Fernet
import tenseal as ts


Expand All @@ -12,37 +13,43 @@ def __init__(self):
self.data_counts: Dict[str, int] = {}
self.training_data: Dict[str, Dict[str, List[ts.CKKSVector]]] = {}
self.models: Dict[str, EncryptedLR] = {}
self.key = b""

def set_key(self, key: bytes):
self.key = key

def handle_request(self, request: str) -> str:
cur_key = Fernet(self.key)
request = cur_key.decrypt(request)
request_dict = deserialize(request)
context = ts.context_from(bytes.fromhex(request_dict['context']))

if request_dict['action'] == 'store':
return serialize(self.store_data(context, request_dict['key'], request_dict['data'], request_dict['size']))
return cur_key.encrypt(serialize(self.store_data(context, request_dict['key'], request_dict['data'], request_dict['size'])).encode('utf-8'))
elif request_dict['action'] == 'compute_average':
return serialize(self.compute_average(context, request_dict['key']))
return cur_key.encrypt(serialize(self.compute_average(context, request_dict['key'])).encode('utf-8'))
elif request_dict['action'] == 'compute_variance':
return serialize(self.compute_variance(context, request_dict['key']))
return cur_key.encrypt(serialize(self.compute_variance(context, request_dict['key'])).encode('utf-8'))
elif request_dict['action'] == 'sd':
return serialize(self.compute_standard_deviation(context, request_dict['key']))
return cur_key.encrypt(serialize(self.compute_standard_deviation(context, request_dict['key'])).encode('utf-8'))
elif request_dict['action'] == 'compute_overall_average':
return serialize(self.compute_overall_average(context, request_dict['keys']))
return cur_key.encrypt(serialize(self.compute_overall_average(context, request_dict['keys'])).encode('utf-8'))
elif request_dict['action'] == 'store_training_data':
return serialize(self.store_training_data(context, request_dict['key'], request_dict['training_data']))
return cur_key.encrypt(serialize(self.store_training_data(context, request_dict['key'], request_dict['training_data'])).encode('utf-8'))
elif request_dict['action'] == 'initialize_model':
return serialize(self.initialize_model(context, request_dict['key'], request_dict['n_features']))
return cur_key.encrypt(serialize(self.initialize_model(context, request_dict['key'], request_dict['n_features'])).encode('utf-8'))
elif request_dict['action'] == 'train_epoch':
return serialize(self.train_epoch(context, request_dict['key']))
return cur_key.encrypt(serialize(self.train_epoch(context, request_dict['key'])).encode('utf-8'))
elif request_dict['action'] == 'get_model_params':
return serialize(self.get_model_params(request_dict['key']))
return cur_key.encrypt(serialize(self.get_model_params(request_dict['key'])).encode('utf-8'))
elif request_dict['action'] == 'set_model_params':
return serialize(self.set_model_params(context, request_dict['key'], request_dict['params']))
return cur_key.encrypt(serialize(self.set_model_params(context, request_dict['key'], request_dict['params'])).encode('utf-8'))
elif request_dict['action'] == 'predict':
return serialize(self.predict(context, request_dict['key'], request_dict['inference_data']['x']))
return cur_key.encrypt(serialize(self.predict(context, request_dict['key'], request_dict['inference_data']['x'])).encode('utf-8'))
elif request_dict['action'] == 'predict_all':
return serialize(self.predict_all(context, request_dict['key'], request_dict['inference_data']['x']))
return cur_key.encrypt(serialize(self.predict_all(context, request_dict['key'], request_dict['inference_data']['x'])).encode('utf-8'))
else:
return serialize({'status': 'error', 'message': 'Invalid action'})
return cur_key.encrypt(serialize({'status': 'error', 'message': 'Invalid action'}).encode('utf-8'))

def store_data(self, context: ts.Context, key: str, data: str, size: int) -> Dict[str, str]:
if key in self.storage:
Expand Down

0 comments on commit 303e001

Please sign in to comment.