-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlwe_sugoi.py
125 lines (94 loc) · 3.17 KB
/
lwe_sugoi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# https://en.wikipedia.org/wiki/Learning_with_errors#Public-key_cryptosystem
# Oded Regev, “On lattices, learning with errors, random linear codes, and cryptography,” in Proceedings of the thirty-seventh annual ACM symposium on Theory of computing (Baltimore, MD, USA: ACM, 2005), 84–93
import math
import random
from functools import reduce
from mpmath import jtheta, quad, mp
def prime(a):
    """Sieve of Eratosthenes: list every prime in the closed range [2, a].

    The result is sorted ascending and is empty for 0 <= a < 2; a negative
    ``a`` raises ValueError from ``math.sqrt``.
    """
    sieve = [True] * (a + 1)
    limit = int(math.sqrt(a) + 0.1) + 1
    candidate = 2
    while candidate < limit:
        if sieve[candidate]:
            # Multiples below candidate**2 were already crossed off by a
            # smaller prime factor, so start striking at candidate**2.
            start = candidate * candidate
            sieve[start::candidate] = [False] * len(range(start, a + 1, candidate))
        candidate += 1
    return [k for k in range(2, a + 1) if sieve[k]]
def dot(x, y, p):
    """Return the inner product of x and y reduced modulo p.

    Pairs are formed with ``zip``, so the longer input is silently truncated.
    """
    total = 0
    for u, v in zip(x, y):
        total += u * v
    return total % p
def add(x, y, p):
    """Return the component-wise sum (x + y) mod p, with len(x) entries.

    y is indexed position by position, so it must be at least as long as x.
    """
    return [(xi + y[i]) % p for i, xi in enumerate(x)]
def add_vectors(matrix, p, width=None):
    """Return the component-wise sum of the rows of ``matrix``, reduced modulo p.

    Args:
        matrix: sequence of equal-length integer rows.
        p: modulus applied to every component of the result.
        width: length of the zero vector returned for an empty ``matrix``.
            An empty matrix has no first row to take the width from, so
            without ``width`` it yields ``[]`` (it used to raise IndexError).

    Returns:
        A list with one entry per column: ``sum(column) % p``.
    """
    if not matrix:
        return [0] * (width or 0)
    # Summing each column once and reducing at the end yields the same residue
    # as reducing after every pairwise addition, without intermediate lists.
    return [sum(column) % p for column in zip(*matrix)]
### consts
def _error_weights(n, p):
    """Return p sampling weights for the LWE error term e in Z_p.

    Weight i is the integral of the wrapped-Gaussian density over
    [(i - 0.5)/p, (i + 0.5)/p], i.e. the probability that round(p * x) == i.
    Only i >= 0 is integrated, because the density is symmetric about 0, and
    residue p - i reuses weight i. Integration stops once a weight drops
    below 1e-10, or at i == (p - 1) // 2 at the latest, so the mirrored list
    never exceeds p entries. Without that cap, a wide distribution on a small
    modulus never reached 1e-10 and looped forever. Unvisited residues get
    weight 0.
    """
    alpha = 1 / (math.sqrt(n) * math.log(n))
    nome = math.exp(-alpha * alpha * math.pi)

    def density(r):
        # theta_3(-pi r, nome) = 1 + 2 * sum_k nome**(k*k) * cos(2 pi k r).
        # By Poisson summation this is the density of a Gaussian of width
        # alpha wrapped onto [0, 1).
        return jtheta(3, -math.pi * r, nome)

    print("Generating distribution....")
    step = max(1, p // 100)  # p // 100 is 0 for p < 100 -> modulo by zero
    half = []
    for i in range((p - 1) // 2 + 1):
        if i % step == 0:
            print("%s/%s" % (i, p - 1))
        weight = quad(density, [(i - 0.5) / p, (i + 0.5) / p])
        half.append(weight)
        if weight < 1e-10:
            break
    print("%s/%s" % (p - 1, p - 1))
    # Layout: weights 0..k, then zeros, then weights k..1 for residues p-k..p-1.
    return half + [0] * (p - 2 * len(half) + 1) + half[:0:-1]


def genconsts(n = 128):
    """Generate Regev-style LWE parameters together with a key pair.

    Args:
        n: dimension of the secret vector. Must be at least 2, because the
            error width alpha = 1 / (sqrt(n) * log(n)) divides by log(n).

    Returns:
        (m, a, b, p, s):
        - m = 5 * n, the number of public samples.
        - a, an m x n matrix with entries uniform in [0, p).
        - b, with b[i] = (<a[i], s> + e[i]) mod p, where e[i] is drawn from
          ``_error_weights(n, p)``.
        - p, a prime chosen uniformly from the primes in [n^2, 2 n^2].
        - s, the secret vector of length n, uniform in [0, p).

    Raises:
        ValueError: if n < 2.

    NOTE(review): all randomness comes from the ``random`` module, which is not
    a cryptographically secure generator. That is acceptable for this demo but
    not for real keys.
    """
    if n < 2:
        raise ValueError("n must be at least 2, got %s" % n)
    p = random.choice([q for q in prime(2 * n * n) if q >= n * n])
    # The commented-out original derived m = (1 + eps)(n + 1) log p; this demo
    # uses the fixed choice m = 5n instead.
    m = 5 * n
    x_dist = _error_weights(n, p)
    ### priv key
    s = [random.randint(0, p - 1) for _ in range(n)]
    ### public key
    a = [[random.randint(0, p - 1) for _ in range(n)] for _ in range(m)]
    e = random.choices(range(p), weights=x_dist, k=m)
    b = [(dot(row, s, p) + err) % p for row, err in zip(a, e)]
    return (m, a, b, p, s)
### encryption
def encryption(bit, m, a, b, p):
    """Encrypt one bit with the public key (a, b) modulo p.

    Each of the m public samples is included independently with probability
    1/2. The ciphertext is (sum of the chosen rows of a mod p, sum of the
    matching b entries mod p). For any nonzero ``bit``, p // 2 is added to the
    second component.

    Returns:
        (enc_a, enc_b): enc_a is a list of len(a[0]) residues; enc_b is an int.
    """
    subset = [i for i in range(m) if random.randint(0, 1) == 0]
    rows = [a[i] for i in subset]
    # add_vectors cannot infer a width from an empty list, so the (rare for
    # large m) empty subset is mapped to the all-zero vector explicitly.
    enc_a = add_vectors(rows, p) if rows else [0] * len(a[0])
    enc_b = sum(b[i] for i in subset) % p
    if bit != 0:
        enc_b = (enc_b + p // 2) % p
    return (enc_a, enc_b)
### decryption
def decryption(enc, s, p):
    """Recover the bit carried by ciphertext ``enc`` using secret key ``s``.

    Computes v = (enc[1] - <enc[0], s>) mod p. Returns 0 when
    min(v, p - v) is strictly smaller than |v - p // 2|, and 1 otherwise, so
    ties decode as 1.
    """
    enc_a, enc_b = enc[0], enc[1]
    v = (enc_b - dot(enc_a, s, p)) % p
    near_zero = min(v, p - v)
    near_half = abs(v - p // 2)
    return 0 if near_zero < near_half else 1
### some test ...
# Round-trip self-check: encrypt random bits and count how many decrypt back
# correctly. Guarded so that importing this module does not generate keys,
# run 1000 encryptions and print the (very large) public key.
if __name__ == "__main__":
    (m, a, b, p, s) = genconsts()
    print("key length: %s" % m)
    print("public key a: %s" % a)
    print("public key b: %s" % b)
    print("modulo: %s" % p)
    print("secret key: %s" % s)
    samples = 1000
    ok = 0
    for _ in range(samples):
        bit = random.randint(0, 1)
        enc = encryption(bit, m, a, b, p)
        if decryption(enc, s, p) == bit:
            ok += 1
    error = samples - ok
    print("ok: %s/%s" % (ok, samples))
    print("error: %s/%s" % (error, samples))