-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdevice.cu
168 lines (119 loc) · 4.2 KB
/
device.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// Parameters for generate_kernel, passed to the kernel by value.
// Every thread runs num_sims simulations and writes one packed int to results[idx].
struct thread_info {
int num_sims; // Number of simulations each thread performs
int thresh_aa; // Exclusive upper bound of the AA index range (see below)
int thresh_ab; // Exclusive upper bound of the AB index range (see below)
int thresh_bb; // NOTE(review): not read by generate_kernel; an earlier note said it is "implied to equal num_sims and be above 65536" -- confirm intent
/* Cumulative thresholds over a random index r = curand() >> scaledown_factor,
 * as used by generate_kernel:
 *   [0, thresh_aa)          -> AA
 *   [thresh_aa, thresh_ab)  -> AB
 *   [thresh_ab, ...)        -> BB (presumably up to thresh_bb)
 * i.e. thresh_aa, thresh_ab - thresh_aa and (end - thresh_ab) are the sizes of the
 * AA, AB and BB ranges, not per-genotype counts.
 * NOTE(review): the original author was unsure this mapping is already what the
 * host sets up -- confirm against the code that fills this struct.
 */
int scaledown_factor; // Right shift applied to each 32-bit curand() draw; intended to be 32 - log2(thresh_bb) so indices fall in [0, thresh_bb). Must be in [0, 31].
int * results; // Device-accessible array, one int per thread: (num_aa << 16) | num_ab; num_bb = num_sims - (num_aa + num_ab)
// Packing both counts into one int needs one global store per thread instead of two.
// Each count has only 16 bits, so counts above 65535 do not fit.
};
// Open question: does passing kernel parameters as one struct perform any differently from passing them as separate arguments?
// Initializes one cuRAND generator state per thread.
//
// Launch: 1-D grid of 1-D blocks. The thread with global index idx reads
// seeds[idx] and writes state[idx], so both arrays must hold at least
// gridDim.x * blockDim.x elements. There is no bounds check.
//
// Each thread uses its own seed and, additionally, its global index as the
// cuRAND subsequence number.
__global__ void setup_kernel(curandState *state, long long *seeds) {
    // Use the actual launch configuration (blockDim.x) rather than a
    // compile-time block-size macro, so indexing stays correct if the kernel
    // is ever launched with a different block size.
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned long long seed = static_cast<unsigned long long>(seeds[idx]);
    // curand_init(seed, subsequence, offset, out_state).
    // Note: &state[idx] and state + idx are the same pointer.
    curand_init(seed, idx, 0, &state[idx]);
}
// Returns the allele (0 = A, 1 = B) passed on by the parent selected by
// `index`, using the same half-open ranges as the one-parent path of
// generate_kernel:
//   [0, thresh_aa)          -> AA parent, always passes on A
//   [thresh_aa, thresh_ab)  -> AB parent, the index's low bit picks the allele
//   [thresh_ab, ...)        -> BB parent, always passes on B
static __device__ __forceinline__ unsigned int
generate_kernel_parent_allele(unsigned int index, unsigned int thresh_aa, unsigned int thresh_ab) {
    if (index < thresh_aa) {
        return 0u;
    }
    if (index < thresh_ab) {
        return index & 1u;
    }
    return 1u;
}

// Runs t_info.num_sims simulations per thread and stores packed genotype
// counts in t_info.results[idx].
//
// Launch: 1-D grid of 1-D blocks. The thread with global index idx uses
// curandstate[idx] (expected to be initialized, e.g. by setup_kernel) and
// writes results[idx]; both arrays need gridDim.x * blockDim.x elements.
// There is no bounds check.
//
// Each simulation is assigned by one random bit to either:
//  - the one-parent path: one random index is classified AA/AB/BB by the
//    thresholds, or
//  - the two-parent path: two random indices each yield one allele
//    (generate_kernel_parent_allele); the offspring is AA, AB or BB by the
//    number of B alleles (0, 1 or 2).
// Random indices are curand() >> t_info.scaledown_factor (which must be in
// [0, 31]).
//
// Output: results[idx] = (num_aa << 16) | num_ab with each count truncated to
// 16 bits; num_bb = num_sims - (num_aa + num_ab). Counts above 65535 wrap, so
// num_sims should stay <= 65535 for the packed result to be exact.
// The advanced RNG state is written back to curandstate[idx].
__global__ void generate_kernel(curandState *curandstate, thread_info t_info){
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    curandState rng = curandstate[idx];

    // Read the parameters once. A non-positive num_sims yields zero counts
    // and draws no random numbers.
    const int num_sims = t_info.num_sims > 0 ? t_info.num_sims : 0;
    const int shift = t_info.scaledown_factor;
    const unsigned int thresh_aa = static_cast<unsigned int>(t_info.thresh_aa);
    const unsigned int thresh_ab = static_cast<unsigned int>(t_info.thresh_ab);

    // One-parent vs two-parent split: every bit of a curand() draw is one
    // 50/50 decision, so a popcount over 32 bits decides 32 simulations.
    int num_one_parent = 0;
    for (int i = 0; i < (num_sims >> 5); ++i) {
        num_one_parent += __popc(curand(&rng));
    }
    // Decide the remaining num_sims % 32 simulations with only that many bits
    // of one more draw; otherwise they would always be treated as two-parent.
    const int remaining_flips = num_sims & 31;
    if (remaining_flips != 0) {
        num_one_parent += __popc(curand(&rng) & ((1u << remaining_flips) - 1u));
    }
    const int num_two_parents = num_sims - num_one_parent;

    // 32-bit counters avoid 16-bit arithmetic; they are truncated when packed.
    unsigned int num_aa = 0;
    unsigned int num_ab = 0;
    // num_bb is implicit: num_sims - (num_aa + num_ab)

    for (int i = 0; i < num_one_parent; ++i) {
        const unsigned int parent = curand(&rng) >> shift;
        if (parent < thresh_aa) {
            ++num_aa;
        } else if (parent < thresh_ab) {
            ++num_ab;
        }
        // else: BB offspring (implicit)
    }

    for (int i = 0; i < num_two_parents; ++i) {
        // First parent is drawn before the second; both use the same
        // half-open ranges as the one-parent path above.
        const unsigned int first = curand(&rng) >> shift;
        const unsigned int second = curand(&rng) >> shift;
        const unsigned int num_b =
            generate_kernel_parent_allele(first, thresh_aa, thresh_ab) +
            generate_kernel_parent_allele(second, thresh_aa, thresh_ab);
        if (num_b == 0u) {
            ++num_aa;
        } else if (num_b == 1u) {
            ++num_ab;
        }
        // else: num_b == 2, BB offspring (implicit)
    }

    curandstate[idx] = rng;
    // Pack in unsigned arithmetic: left-shifting a signed int holding a value
    // >= 32768 by 16 would overflow (undefined behavior).
    const unsigned int packed = ((num_aa & 0xFFFFu) << 16) | (num_ab & 0xFFFFu);
    t_info.results[idx] = static_cast<int>(packed);
}