-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxi.py
134 lines (116 loc) · 4.31 KB
/
xi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
# CW2B and EH3
def hash31(a, b, x):
    """Carter-Wegman (CW) hash of ``x`` folded down to 31 bits.

    Adapted from MassDal:
    http://www.cs.rutgers.edu/~muthu/massdal-code-index.html

    The linear form ``a * x + b`` is evaluated first. The Mersenne trick is
    then applied: everything above bit 30 is shifted down and added back onto
    the value, and only the low 31 bits of that sum are kept.

    Returns a value in ``[0, 2**31 - 1]``.
    """
    low_31_bits = 2147483647  # 2**31 - 1
    linear = a * x + b
    carry = linear >> 31
    return (carry + linear) & low_31_bits
def seq_xor(x):
    """Return the parity (0 or 1) of the low 32 bits of ``x``.

    ``x`` may be a Python int or an integer tensor; for a tensor the parity
    is computed element-wise. Bits above position 31 do not contribute,
    because the folding starts with a 16-bit shift.

    The argument is never modified: each step rebinds a new value instead of
    using ``^=``, which would operate in place on a tensor and overwrite the
    caller's data.
    """
    # Fold the upper half onto the lower half until a single bit remains.
    for shift in (16, 8, 4, 2, 1):
        x = x ^ (x >> shift)
    return x & 1
def EH3(i0, I1, j):
    """+-1 random variables, 3-wise independent scheme (EH3).

    Args:
        i0: seed whose lowest bit is XORed with the parity bit; only that
            bit influences the result.
        I1: seed that is ANDed bit-wise with ``j``.
        j: index (Python int or integer tensor) to evaluate the scheme at.

    Returns:
        A tensor holding ``1`` where the combined bit is set and ``-1``
        elsewhere.
    """
    # 0xAAAAAAAA keeps the odd bit positions, so ``j & (j << 1) & mask``
    # leaves, at each odd position, the AND of that bit of ``j`` with the
    # bit directly below it.
    mask = 0xAAAAAAAA
    p_res = (I1 & j) ^ (j & (j << 1) & mask)
    bit_is_set = ((i0 ^ seq_xor(p_res)) & 1) != 0
    # as_tensor is a no-op for tensors; it lets plain Python ints through,
    # where the comparison above yields a bool that torch.where rejects.
    return torch.where(torch.as_tensor(bit_is_set), 1, -1)
def CW2B(a, b, x, M):
    """b-valued random variables, 2-wise CW scheme.

    Hashes ``x`` with ``hash31`` using the seeds ``a`` and ``b`` and reduces
    the result modulo ``M``.
    """
    return hash31(a, b, x) % M
class B_Xi(object):
    """Hash to B buckets.

    Instances are callable: ``b_xi(j)`` evaluates ``CW2B`` at ``j`` with the
    two seeds held in ``self.seeds`` (an int64 tensor of length 2) and the
    bucket count held in ``self.num_buckets``.
    """

    def __init__(self, B, I1=2**32-1, I2=2**32-1, preprocess=True):
        """Store the bucket count ``B`` and build the seed pair.

        With ``preprocess`` enabled, ``I1`` and ``I2`` are scrambled into two
        32-bit seeds; otherwise they are stored exactly as given.
        """
        super().__init__()
        self.num_buckets = B
        given = torch.tensor([I1, I2], dtype=torch.int64)
        if not preprocess:
            self.seeds = given
            return

        half, word = 0x0000ffff, 0xffffffff
        hi, lo = given[0], given[1]
        # First seed: low 16 bits of I1 on top, low 16 bits of I2 below.
        first = ((hi << 16) ^ (lo & half)) & word
        # Advance both inputs one multiply-and-fold step before combining
        # them again for the second seed.
        hi = 36969 * (hi & half) + (hi >> 16)
        lo = 18000 * (lo & half) + (lo >> 16)
        second = ((hi << 16) ^ (lo & half)) & word
        self.seeds = torch.stack((first, second))

    def element(self, j):
        """Return the bucket that index ``j`` hashes to."""
        seed_a, seed_b = self.seeds
        return CW2B(seed_a, seed_b, j, self.num_buckets)

    def __call__(self, j):
        return self.element(j)

    def __str__(self):
        seed_a, seed_b = self.seeds
        return "{}-wise xi({}, {})".format(self.num_buckets, seed_a, seed_b)

    def __repr__(self):
        return str(self)
class Xi(object):
    """Hash to pos or neg 1.

    Instances are callable: ``xi(j)`` evaluates ``EH3`` at ``j`` with the two
    seeds held in ``self.seeds`` (an int64 tensor of length 2).
    """

    def __init__(self, I1=2**32-1, I2=2**32-1, preprocess=True):
        """Build the seed pair from ``I1`` and ``I2``.

        With ``preprocess`` enabled the inputs are scrambled into two 32-bit
        seeds; otherwise they are stored exactly as given.
        """
        super().__init__()

        def combine(upper, lower):
            # Low 16 bits of ``upper`` become the high half of a 32-bit
            # word, low 16 bits of ``lower`` the low half.
            return ((upper << 16) ^ (lower & 0x0000ffff)) & 0xffffffff

        pair = torch.tensor([I1, I2], dtype=torch.int64)
        if preprocess:
            u, v = pair[0], pair[1]
            seed_0 = combine(u, v)
            # One multiply-and-fold step on each input before recombining.
            u_next = 36969 * (u & 0x0000ffff) + (u >> 16)
            v_next = 18000 * (v & 0x0000ffff) + (v >> 16)
            seed_1 = combine(u_next, v_next)
            pair = torch.stack((seed_0, seed_1))
        self.seeds = pair

    def element(self, j):
        """Return the +1/-1 value assigned to index ``j``."""
        i0, i1 = self.seeds
        return EH3(i0, i1, j)

    def __call__(self, j):
        return self.element(j)

    def __str__(self):
        i0, i1 = self.seeds
        return "+-1 xi({}, {})".format(i0, i1)

    def __repr__(self):
        return str(self)
if __name__ == '__main__':
    import random

    # Fixed seed so every run draws the same hash seeds below.
    random.seed(2 ** 31 - 1)
    interval = 100000
    seed_max = 2**32-1

    print("seeding test...")
    # Discard the first draws of the generator before picking hash seeds.
    for _ in range(10000):
        random.randint(0, 2**31-1)

    # +-1 scheme: every value must be -1 or 1; report their sum.
    xi = Xi(random.randint(1, seed_max), random.randint(1, seed_max))
    signed_sum = 0
    for idx in range(interval):
        sign = xi(idx)
        assert sign in (-1, 1), (sign, xi)
        signed_sum += sign
    print("{} sum over range({}): {}".format(xi, interval, signed_sum.item()))

    # Bucket scheme: buckets must be non-negative; report the mean of
    # (bucket + 1).
    b_xi = B_Xi(100, random.randint(1, seed_max), random.randint(1, seed_max))
    bucket_total = 0
    for idx in range(interval):
        bucket = b_xi(idx)
        assert bucket >= 0, (bucket, idx)
        bucket_total += bucket + 1
    print("{} average {} / {} = {}".format(
        b_xi, bucket_total, interval, bucket_total/interval))

    # Two separately seeded 10-bucket hashes: count how often they agree.
    b_xi_1 = B_Xi(10, random.randint(1, seed_max), random.randint(1, seed_max))
    b_xi_2 = B_Xi(10, random.randint(1, seed_max), random.randint(1, seed_max))
    hits = 0
    for idx in range(interval):
        hits += (b_xi_1(idx) == b_xi_2(idx))
    print("({}, {}) hit {} / {} = {} rate".format(
        b_xi_1, b_xi_2, hits, interval, hits/interval))

    print("short interval test:")
    x = Xi(1234567, 9876543)
    print(x)
    for idx in range(20):
        print(x(idx))
    b = B_Xi(100, 1234567, 9876543)
    print(b)
    for idx in range(20):
        print(b(idx))

    print("sketch link_type test:")
    x = Xi(1675206430, 3737435780)
    b = B_Xi(50, 1664175982, 431896386)
    for idx in range(1, 19):
        print("bucket {:>5}: {:>10}".format(b(idx).item(), x(idx).item()))