-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathc_jax.qmd
395 lines (307 loc) · 12.1 KB
/
c_jax.qmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
---
title: "Du C++ depuis JAX"
format: html
toc: true
author:
- Hugo Gangloff
date: 21/08/2024
---
# FinistR : bootcamp à Roscoff
## Objectifs
Dans ce tutoriel, qui constitue le suite de [Du C++ depuis
Python](https://stateofther.github.io/finistR2024/c_python.html), nous étudions
le module `jax.extend.ffi` dont un tutoriel est donné
[ici](https://jax.readthedocs.io/en/latest/ffi.html). Plus précisément, nous
allons adapter ce dernier pour l'utilisation d'un algorithme naïf
d'échantillonneur de Gibbs écrit en C++ depuis JAX.
## Le projet
Tous les fichiers du projet sont disponibles sur [ce dépôt
Github](https://github.com/HGangloff/jax_ffi_experiments). Les outils et
l'arborescence sont similaires au tutoriel *C++ depuis Python*.
## Code C++
Nous avons à disposition une mini bibliothèque C++ implémentant une version naïve
de l'échantillonneur de Gibbs. Les trois fonctions qu'elle offre sont
disponibles dans `lib/gibbs_sampler.h` :
```Cpp
void initialize(int* image, int rows, int cols, int Q);
void printImage(int* image, int rows, int cols);
void RunGibbsSampler(int* image, int rows, int cols, int Q, float beta, int iter);
```
Le fichier `gibbs_sampler_cpp.cpp` chronomètre cette implémentation de
l'échantillonneur de Gibbs pour $1000$ itérations, sur une image de
$200\times200$, pour un modèle de Potts à $3$ classes où $\beta=1.0$ :
```Cpp
// Thanks chatgpt for the quick gibbs sampler implementation
#include <iostream>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <chrono>
#include "lib/gibbs_sampler.h"
using namespace std;
using namespace std::chrono; // Use the chrono namespace
// Main function
int main() {
int rows = 200;
int cols = 200;
int Q = 3; // Number of possible labels (states)
float beta = 1.0; // Coupling strength
int iter = 1000;
// Allocate memory for the image
int* image = new int[rows * cols];
auto start = high_resolution_clock::now();
// Initialize the image with random states
initialize(image, rows, cols, Q);
// cout << "Initial image:" << endl;
// printImage(image, rows, cols);
RunGibbsSampler(image, rows, cols, Q, beta, iter);
auto stop = high_resolution_clock::now();
auto duration = duration_cast<milliseconds>(stop - start);
cout << "Execution time: " << duration.count() << " milliseconds" << endl;
// cout << "Final image after Gibbs sampling:" << endl;
// printImage(image, rows, cols);
// Free allocated memory
delete[] image;
return 0;
}
```
Le fichier ci-dessus est compilé et exécuté avec la commande :
```bash
bazel build main && bazel-bin/main
```
L'instruction `build main` se trouve dans `BUILD.bazel` (il s'agit d'une simple
compilation d'exécutable `cc_binary`) :
```bash
cc_binary( # no header here, we are not building a library
name = "main",
srcs = ["gibbs_sampler_cpp.cpp"],
visibility = ["//visibility:public"],
deps = ["//lib:gibbs_sampler_lib"]
)
```
Mais on remarque que cette règle de compilation dépend d'une autre règle de
compilation qui est celle de la bibliothèque de l'échantillonneur de Gibbs
`gibbs_sampler_lib`. Cette règle, exécutée automatiquement par dépendance, se situe
dans `lib/BUILD.bazel` :
```bash
cc_library(
name = "gibbs_sampler_lib",
srcs = ["gibbs_sampler.cpp"],
hdrs = [
"gibbs_sampler.h", # These headers are intended to be included by other libraries or binaries that depend on this library.
],
copts = ["-std=c++17"],
visibility = ["//visibility:public"],
)
```
Finalement, l'exécution du programme donne :
```bash
Execution time: 3718 milliseconds
```
Retenons cette valeur à laquelle nous allons nous comparer plus tard. Retenons
également l'illustration (voir le dépôt git pour le code C++ incluant la
sauvegarde de l'image) :
![Résultat de `gibbs_sampler_cpp.cpp`](pure_cpp.png)
## C++ depuis JAX
### Wrapping dans XLA
Comme indiqué dans les deux seuls tuto sur le sujet actuellement disponibles
(sur la [doc de JAX](https://jax.readthedocs.io/en/latest/ffi.html) et sur la
[doc XLA](https://openxla.org/xla/custom_call)), il va nous falloir wrapper
les fonctions de notre bibliothèque dans l'API des *custom call* de
XLA.
Nous commençons par télécharger l'API en trois fichiers, `api.h`, `c_api.h`
et `ffi.h`, que nous plaçons dans `lib/xla/ffi/api/`. Nous créons ensuite une
fonction `GibbsSamplerImpl` qui va envelopper `initialize` et `RunGibbsSampler`
disponibles dans `gibbs_sampler.h`. Ce wrapping s'effectue à l'aide d'objets
particuliers : `Buffer` (représentant les `jax.numpy` arrays), `Datatype`,
`Error`, `Bind` et `Attr`. Ces objets sont à l'intermédiaire des structures de
données de `jax.numpy` et de C++.
Leur utilisation est décrite dans les tuto cités plus hauts.
On note de manière intéressante que `GibbsSamplerImpl` ne prend pas de tableau
(`jnp.array`) en entrée, i.e., pas de `Bind.Arg<>()` ; les seuls arguments
d'entrée sont des *attributs nommés*, i.e., `Bind().Attr<>()`. Il y aura
cependant un retour qui sera le tableau du résultat de l'échantillonneur, i.e.,
`Bind().Ret<>()`. Finalement, la vraie fonction d'intérêt qui pourra être utilisable via
le module `jax.extend` est construite par la macro
`XLA_FFI_DEFINE_HANDLER_SYMBOL` et s'appelera `GibbsSampler`.
Toutes ces opérations sont écrites dans
`lib/gibbs_sampler_xla.cpp` :
```cpp
#include <cmath>
#include <cstdint>
#include <numeric>
#include <type_traits>
#include <utility>
#include <functional>
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
extern void initialize(int* image, int rows, int cols, int Q);
extern void RunGibbsSampler(int* image, int rows, int cols, int Q, float beta, int iter);
namespace ffi = xla::ffi;
ffi::Error GibbsSamplerImpl(
//ffi::BufferR1<ffi::DataType::S32> in_img,
ffi::Result<ffi::BufferR2<ffi::DataType::S32>> img,
int rows,
int cols,
int Q,
float beta,
int iter
) {
initialize(img->typed_data(), rows, cols, Q);
RunGibbsSampler(img->typed_data(), rows, cols, Q, beta, iter);
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
GibbsSampler, GibbsSamplerImpl,
ffi::Ffi::Bind()
// .Arg<ffi::BufferR1<ffi::DataType::S32>>()
.Ret<ffi::BufferR2<ffi::DataType::S32>>()
.Attr<int>("rows")
.Attr<int>("cols")
.Attr<int>("Q")
.Attr<float>("beta")
.Attr<int>("iter")
);
```
Afin d'alléger le code, le wrapping de la bibliothèque `gibbs_sampler.h` à l'aide
de l'API XLA, constitue une bibliothèque intermédiaire dont le header est
`gibbs_sampler_xla.h`. On la compile avec la règle `gibbs_sampler_xla_lib`
définie dans `lib/BUILD.bazel` :
```bash
cc_library(
name = "gibbs_sampler_xla_lib",
srcs = [
"xla/ffi/api/c_api.h",
"xla/ffi/api/ffi.h",
"xla/ffi/api/api.h",
"gibbs_sampler_xla.cpp",
], # When deciding whether to put a header into hdrs or srcs, you should ask whether you want consumers of this library to be able to directly include it
hdrs = [
"gibbs_sampler_xla.h",
],
includes = [ # need to use include -I option is not available due to sandbox stuff
".", # adding to srcs or hdrs does not include
],
copts = ["-std=c++17 -rdynamic"], # -rdynamic to use extern
deps = ["//lib:gibbs_sampler_lib"],
visibility = ["//visibility:public"],
)
```
On note que cette règle nécessite l'exécution de la règle `gibbs_sampler_lib`.
### Création du module Python
Comme dans le tuto *C++ depuis Python* nous allons utiliser `pybind` et
les `PyCapsule` pour cette partie des opérations. Nous créons alors un module
python appelé `gibbs_sampler` contiendra une fonction elle-même appelée
`gibbs_sampler`. L'appel de cette dernière retournera une `PyCapsule` contenant
la fonction `GibbsSampler` définie à la fin de la section précédente.
Le code effectuant tout cela est dans `lib/gibbs_sampler_xla_pybind.cc`.
```cpp
#include <pybind11/pybind11.h>
#include "gibbs_sampler_xla.h"
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(reinterpret_cast<void*>(fn));
}
PYBIND11_MODULE(gibbs_sampler, m) { // please match the pybind_extension target name
m.def("gibbs_sampler", []() { return EncapsulateFunction(GibbsSampler); });
}
```
Le code se compile en un module python avec la règle `gibbs_sampler` définie dans
`lib/BUILD.bazel` :
```bash
pybind_extension( # must be in two steps (first cc_library then pybind_extension)
name = "gibbs_sampler", # must match the pybind module name
srcs = ["gibbs_sampler_xla_pybind.cc"],
deps = ["//lib:gibbs_sampler_xla_lib"],
copts = ["-std=c++17 -rdynamic"]
)
```
On note la dépendance sur la règle `gibbs_sampler_xla_lib`.
Finalement le module Python peut se compiler avec la simple commande
```bash
bazel build //lib:gibbs_sampler
```
### Utilisation du module Python avec `jax.extend.ffi`
Nous sommes enfin en mesure d'utiliser la fonction C++ depuis JAX !
Dans le fichier `gibbs_sampler_c_from_JAX.py` :
1) Nous enregistrons notre appel FFI (nom, PyCapsule, platform) :
```bash
jex.ffi.register_ffi_target("gibbs_sampler", gibbs_sampler_lib.gibbs_sampler(), platform="cpu")
```
2) Nous définissons une fonction Python `gibbs_sampler_cpp` qui effectue
l'appel à `GibbsSampler`, la fonction cachée dans la PyCapsule, à l'aide de
`jax.ffi.ffi_call`. On note que cet appel est maintenant plutôt simple.
Il faut préciser les dimensions et le
type du résultat retourné avec un object `jax.ShapeDtypeStruct`, il ne nous
faut pas préciser d'argument en entrée car nous n'avons que des attributs qui
sont passés en *keyword argument*.
3) Nous instruisons l'appel avec les mêmes paramètres qu'auparavant et le chronométrons :
```python
rows = 200
cols = 200
Q = 3
beta = 1
n_iter = 1000
start = time.time()
res = gibbs_sampler_cpp(rows, cols, Q, beta, n_iter)
end = time.time()
print(f"Time: {end - start} seconds")
```
L'exécution de ce script Python donne :
```bash
Time: 3.9689433574676514 seconds
```
Ce qui est un temps très légèrement supérieur à l'appel directement depuis C++
à la même fonction avec les mêmes paramètres. Nous obtenons donc le résultat
espéré et terme de temps, ainsi qu'en illustration :
![Résultat de `gibbs_sampler_c_from_JAX.py`](c_from_JAX.png)
## Comparaison avec Python / JAX
La bibliothèque JAX de Python offre la fonctionnalité clé de compilation à la
volée (*Just In Time compilation*) qui permet une compilation du code
Python à la première exécution pour des exécutions ultérieures bien plus
rapides, **en général**.
### Sur CPU
La bibliothèque [mrfx](https://github.com/HGangloff/mrfx) offre une implémentation
en pur JAX de l'échantillonneur de Gibbs et de l'échantillonneur de Gibbs
chromatique. Nous pouvons obtenir un chronométrage de l'échantillonneur de
Gibbs grâce au code du fichier `gibbs_sampler_JAX_cpu.py` :
```python
import os
os.environ["JAX_PLATFORMS"] = "cpu"
import time
import jax
key = jax.random.PRNGKey(0)
from mrfx.models import Potts
from mrfx.samplers import GibbsSampler
from mrfx.experiments import time_complete_sampling
rows = 200
cols = 200
Q = 3
beta = 1
n_iter = 1000
key, subkey = jax.random.split(key, 2)
times, n_iterations = time_complete_sampling(
GibbsSampler,
Potts,
subkey,
[Q],
[(rows, cols)],
5,
kwargs_sampler={"eps": 0.01, "max_iter":n_iter, "cv_type":"iter_only"},
kwargs_model={"beta":1.}
)
print("JAX on CPU", times)
```
On note un temps de compilation et d'exécution affiché de $7.7714$ secondes et un temps
d'exécution seule de $7.3613$.
### Sur GPU
Le fichier `gibbs_sampler_JAX_gpu.py` permet d'excécuter le code précédent sur
GPU. En revanche, il apparaît que l'algorithme est environ 10 fois plus lent à
la compilation et à l'exécution. C'est probablement dû à l'aspect hautement
itératif de l'algorithme étudié qui s'adapte mal à une architecture GPU...
## Conclusion
Nous avons avec succès appelé du code C++ depuis JAX.
Nous pouvons conclure que pour cette algorithme d'échantillonnage de Gibbs
hautement itératif, la compilation à la volée de JAX ne surpasse pas un appel à
du code C++. Le gain à l'utilisation de `jax.extend.ffi` est ici très
important.