From d2c77124544fe9a38378d9ad944e8bd6398a86f6 Mon Sep 17 00:00:00 2001 From: Vivek Pandya Date: Mon, 5 Feb 2024 17:43:32 +0530 Subject: [PATCH 1/3] Initial commit --- Cargo.lock | 216 ++++++++++++++++------- Cargo.toml | 6 +- circuits/Cargo.toml | 5 +- circuits/src/cpu/add.rs | 5 + circuits/src/lib.rs | 1 + circuits/src/stark/prover.rs | 121 ++++++++++++- circuits/src/stark/recursive_verifier.rs | 10 +- circuits/src/stark/serde.rs | 6 +- circuits/src/stark/verifier.rs | 9 +- circuits/src/test_utils.rs | 76 ++++++++ cli/Cargo.toml | 3 + cli/src/main.rs | 48 ++++- 12 files changed, 433 insertions(+), 73 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 705f72856..265685e0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" +checksum = "8b79b82693f705137f8fb9b37871d99e4f9a7df12b917eed79c3d3954830a60b" dependencies = [ "cfg-if", "const-random", @@ -53,9 +53,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.12" +version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540" +checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" dependencies = [ "anstyle", "anstyle-parse", @@ -214,7 +214,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -288,7 +288,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -419,9 +419,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ "crossbeam-utils", ] @@ -467,6 +467,24 @@ dependencies = [ "typenum", ] +[[package]] +name = "cuda-config" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" +dependencies = [ + "glob", +] + +[[package]] +name = "cuda-driver-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d4c552cc0de854877d80bcd1f11db75d42be32962d72a6799b88dcca88fffbd" +dependencies = [ + "cuda-config", +] + [[package]] name = "darling" version = "0.20.8" @@ -488,7 +506,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.10.0", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -499,7 +517,7 @@ checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" dependencies = [ "darling_core", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -652,6 +670,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "2.4.0" @@ -687,9 +711,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "379dada1584ad501b383485dd706b8afb7a70fcbc7f4da7d780638a5a6124a60" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -697,6 +721,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "humantime" version = "2.1.0" @@ -770,9 +803,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.3" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -893,9 +926,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "memchr" @@ -951,6 +984,8 @@ dependencies = [ "plonky2", "proptest", "rayon", + "rustacuda", + "rustacuda_core", "serde", "starky", "thiserror", @@ -965,7 +1000,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1169,7 +1204,7 @@ dependencies = [ [[package]] name = "plonky2" version = "0.2.0" -source = "git+https://github.com/0xmozak/plonky2.git#7a3b7d487aa23ca43d43b513b61cf0387b414015" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" dependencies = [ "ahash", "anyhow", @@ -1179,17 +1214,29 @@ dependencies = [ "keccak-hash", "log", "num", + "plonky2-cuda", "plonky2_field", - "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xmozak/plonky2.git)", + "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda)", "plonky2_util", "rand", "rand_chacha", + "rustacuda", + "rustacuda_core", "serde", "static_assertions", "unroll", "web-time", ] +[[package]] +name = "plonky2-cuda" +version = "0.1.0" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" +dependencies = [ + "cc", + "which", +] + [[package]] name = "plonky2_crypto" version = "0.1.0" @@ -1210,13 +1257,15 @@ dependencies = [ [[package]] name = "plonky2_field" version = "0.2.0" -source = "git+https://github.com/0xmozak/plonky2.git#7a3b7d487aa23ca43d43b513b61cf0387b414015" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" dependencies = [ "anyhow", "itertools 0.12.1", "num", "plonky2_util", "rand", + "rustacuda", + "rustacuda_core", "serde", "static_assertions", "unroll", @@ -1231,7 +1280,7 @@ checksum = "92ff44a90aaca13e10e7ddf8fab815ba1b404c3f7c3ca82aaf11c46beabaa923" [[package]] name = "plonky2_maybe_rayon" version = "0.2.0" -source = "git+https://github.com/0xmozak/plonky2.git#7a3b7d487aa23ca43d43b513b61cf0387b414015" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" dependencies = [ "rayon", ] @@ -1239,7 +1288,7 @@ dependencies = [ [[package]] name = "plonky2_util" version = "0.2.0" -source = "git+https://github.com/0xmozak/plonky2.git#7a3b7d487aa23ca43d43b513b61cf0387b414015" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" [[package]] name = "plotters" @@ -1419,9 +1468,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" dependencies = [ "either", "rayon-core", @@ -1481,6 +1530,35 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustacuda" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47208516ab5338b592d63560e90eaef405d0ec880347eaf7742d893b0a31e228" +dependencies = [ + "bitflags 1.3.2", + "cuda-driver-sys", + "rustacuda_core", + "rustacuda_derive", +] + +[[package]] +name = "rustacuda_core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3858b08976dc2f860c5efbbb48cdcb0d4fafca92a6ac0898465af16c0dbe848" + +[[package]] +name = "rustacuda_derive" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43ce8670a1a1d0fc2514a3b846dacdb65646f9bd494b6674cfacbb4ce430bd7e" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "rustc_version" version = "0.4.0" @@ -1584,7 +1662,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1608,7 +1686,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.3", + "indexmap 2.2.5", "serde", "serde_derive", "serde_json", @@ -1625,7 +1703,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1688,7 +1766,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "starky" version = "0.2.0" -source = "git+https://github.com/0xmozak/plonky2.git#7a3b7d487aa23ca43d43b513b61cf0387b414015" +source = "git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda#0df2cdeed87d649637b4f719d0e1d47d5499f2b5" dependencies = [ "ahash", "anyhow", @@ -1697,7 +1775,7 @@ dependencies = [ "log", "num-bigint", "plonky2", - "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xmozak/plonky2.git)", + "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xmozak/plonky2.git?branch=vivek/cuda)", "plonky2_util", ] @@ -1745,9 +1823,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.51" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2", "quote", @@ -1803,7 +1881,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1814,7 +1892,7 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", "test-case-core", ] @@ -1835,7 +1913,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] @@ -1915,7 +1993,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.2.3", + "indexmap 2.2.5", "toml_datetime", "winnow", ] @@ -2088,7 +2166,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", "wasm-bindgen-shared", ] @@ -2110,7 +2188,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2161,6 +2239,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2198,7 +2288,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -2231,7 +2321,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -2251,17 +2341,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.3", - "windows_aarch64_msvc 0.52.3", - "windows_i686_gnu 0.52.3", - "windows_i686_msvc 0.52.3", - "windows_x86_64_gnu 0.52.3", - "windows_x86_64_gnullvm 0.52.3", - "windows_x86_64_msvc 0.52.3", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -2278,9 +2368,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -2296,9 +2386,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -2314,9 +2404,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -2332,9 +2422,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -2350,9 +2440,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -2368,9 +2458,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -2386,9 +2476,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "winnow" @@ -2416,7 +2506,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.51", + "syn 2.0.52", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index eab5ae258..1b9d97a6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,5 +32,7 @@ lto = "fat" lto = "thin" [patch.crates-io] -plonky2 = { git = "https://github.com/0xmozak/plonky2.git" } -starky = { git = "https://github.com/0xmozak/plonky2.git" } +plonky2 = { git = "https://github.com/0xmozak/plonky2.git", branch = "vivek/cuda" } +starky = { git = "https://github.com/0xmozak/plonky2.git", branch = "vivek/cuda" } +# plonky2 = { path = "../plonky2/plonky2" } +# starky = { path = "../plonky2/starky" } diff --git a/circuits/Cargo.toml b/circuits/Cargo.toml index fff9a8f06..8662a1aad 100644 --- a/circuits/Cargo.toml +++ b/circuits/Cargo.toml @@ -26,10 +26,12 @@ serde = { version = "1.0", features = ["derive"] } starky = { version = "0", default-features = false, features = ["std"] } thiserror = "1.0" tt-call = "1.0" +rustacuda = "0.1.3" +rustacuda_core = "0.1.2" +env_logger = { version = "0.10" } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } -env_logger = { version = "0.10" } hex = "0.4" im = "15.1" mozak-examples = { path = "../examples-builder", features = ["fibonacci", "fibonacci-input-new-api"] } @@ -41,6 +43,7 @@ enable_poseidon_starks = [] enable_register_starks = [] test = [] timing = ["plonky2/timing", "starky/timing"] +cuda = ["plonky2/cuda"] [[test]] name = "riscv_tests" diff --git a/circuits/src/cpu/add.rs b/circuits/src/cpu/add.rs index a0f2aae67..9aaeb1ac7 100644 --- a/circuits/src/cpu/add.rs +++ b/circuits/src/cpu/add.rs @@ -77,6 +77,11 @@ mod tests { Stark::prove_and_verify(&program, &record).unwrap(); } + #[test] + fn prove_add_cuda() { + prove_add::>(90, 90, 5); + } + use proptest::prelude::ProptestConfig; use proptest::proptest; proptest! { diff --git a/circuits/src/lib.rs b/circuits/src/lib.rs index eca935be8..adc13677e 100644 --- a/circuits/src/lib.rs +++ b/circuits/src/lib.rs @@ -6,6 +6,7 @@ #![allow(clippy::missing_errors_doc)] // FIXME: Remove this, when proptest's macro is updated not to trigger clippy. #![allow(clippy::ignored_unit_patterns)] +#![feature(allocator_api)] pub mod bitshift; pub mod columns_view; diff --git a/circuits/src/stark/prover.rs b/circuits/src/stark/prover.rs index 9fd8c130a..69d9a34b6 100644 --- a/circuits/src/stark/prover.rs +++ b/circuits/src/stark/prover.rs @@ -31,6 +31,7 @@ use crate::stark::mozak_stark::{all_starks, PublicInputs}; use crate::stark::permutation::challenge::GrandProductChallengeTrait; use crate::stark::poly::compute_quotient_polys; use crate::stark::proof::StarkProofWithMetadata; +use plonky2::fri::oracle::CudaInvContext; /// Prove the execution of a given [Program] /// @@ -48,6 +49,7 @@ pub fn prove( config: &StarkConfig, public_inputs: PublicInputs, timing: &mut TimingTree, + ctx: &mut Option<&mut CudaInvContext> ) -> Result> where F: RichField + Extendable, @@ -63,6 +65,7 @@ where public_inputs, &traces_poly_values, timing, + ctx, ) } @@ -76,14 +79,45 @@ pub fn prove_with_traces( public_inputs: PublicInputs, traces_poly_values: &TableKindArray>>, timing: &mut TimingTree, + ctx: &mut Option<&mut CudaInvContext> ) -> Result> where F: RichField + Extendable, C: GenericConfig, { let rate_bits = config.fri_config.rate_bits; let cap_height = config.fri_config.cap_height; + let trace_commitments; - let trace_commitments = timed!( + #[cfg(feature = "cuda")] + { + trace_commitments = timed!( + timing, + "Compute trace commitments for each table", + traces_poly_values + .clone() + .with_kind() + .map(|(trace, table)| { + timed!( + timing, + &format!("compute trace commitment for {table:?}"), + PolynomialBatch::::from_values_cuda( + trace.clone(), + rate_bits, + false, + cap_height, + timing, + trace.len(), + trace.first().expect("Not a single polynomial").len(), + ctx.as_mut().unwrap(), + ) + ) + }) + ); + } + + #[cfg(not(feature = "cuda"))] + { + trace_commitments = timed!( timing, "Compute trace commitments for each table", traces_poly_values @@ -104,6 +138,8 @@ where ) }) ); + } + // log::info!("trace_commitments {:?}", trace_commitments); let trace_caps = trace_commitments .each_ref() @@ -124,6 +160,23 @@ where &ctl_challenges ) ); + #[cfg(feature = "cuda")] + let proofs_with_metadata = timed!( + timing, + "compute all proofs given commitments", + prove_with_commitments( + mozak_stark, + config, + &public_inputs, + traces_poly_values, + &trace_commitments, + &ctl_data_per_table, + &mut challenger, + timing, + ctx, + )? + ); + #[cfg(not(feature = "cuda"))] let proofs_with_metadata = timed!( timing, "compute all proofs given commitments", @@ -135,7 +188,8 @@ where &trace_commitments, &ctl_data_per_table, &mut challenger, - timing + timing, + &mut None )? ); @@ -171,6 +225,7 @@ pub(crate) fn prove_single_table( ctl_data: &CtlData, challenger: &mut Challenger, timing: &mut TimingTree, + ctx: &mut Option<&mut CudaInvContext>, ) -> Result> where F: RichField + Extendable, @@ -301,6 +356,7 @@ where challenger, &fri_params, timing, + ctx, ) ); @@ -332,6 +388,7 @@ pub fn prove_with_commitments( ctl_data_per_table: &TableKindArray>, challenger: &mut Challenger, timing: &mut TimingTree, + ctx: &mut Option<&mut CudaInvContext> ) -> Result>> where F: RichField + Extendable, @@ -353,6 +410,7 @@ where &ctl_data_per_table[kind], challenger, timing, + ctx, )? })) } @@ -480,4 +538,63 @@ mod tests { }, ]); } + #[test] + #[cfg(feature = "cuda")] + fn test_cuda_poly_batch() { + use plonky2::fri::oracle::PolynomialBatch; + use plonky2::util::timing::TimingTree; + use plonky2::plonk::config::GenericConfig; + use plonky2::plonk::config::PoseidonGoldilocksConfig; + use plonky2::field::types::Sample; + use plonky2::field::polynomial::PolynomialValues; + + +const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + // use plonky2::field::fft::fft_root_table; + let values_num_per_poly = 1 << 6; + let poly_num = 8; + let mut polys = vec![]; + for _i in 0..poly_num { + let poly: Vec = (0..values_num_per_poly).map(|_| F::rand()).collect(); + let poly_as_value = PolynomialValues::new(poly); + polys.push(poly_as_value); + } + let rate_bits = 3; + let cap_height = 4; + let len_cap = 1 << cap_height; + let _all_len = poly_num * values_num_per_poly * (1 << rate_bits); + let num_digests = 2 * (values_num_per_poly * (1 << rate_bits) - len_cap); + let _num_digests_and_caps = num_digests + len_cap; + let blinding = false; + let timing = &mut TimingTree::default(); + let batch: PolynomialBatch = PolynomialBatch::from_values( + polys.clone(), + rate_bits, + blinding, + cap_height, + timing, + None, + ); + let ctx = &mut crate::test_utils::cuda_ctx(); + let cuda_batch: PolynomialBatch = PolynomialBatch::from_values_cuda( + polys, + rate_bits, + blinding, + cap_height, + timing, + poly_num, + values_num_per_poly, + ctx, + ); + assert_eq!(batch.polynomials, cuda_batch.polynomials); + let leaves = batch.merkle_tree.leaves.into_iter().flatten().collect::>(); + assert_eq!(leaves, *cuda_batch.merkle_tree.my_leaves); + assert_eq!(batch.merkle_tree.cap, cuda_batch.merkle_tree.cap); + assert_eq!(batch.merkle_tree.digests.len(), cuda_batch.merkle_tree.my_digests.len()); + assert_eq!(batch.merkle_tree.digests, *cuda_batch.merkle_tree.my_digests); + } + + } diff --git a/circuits/src/stark/recursive_verifier.rs b/circuits/src/stark/recursive_verifier.rs index bd949d677..610dc11cc 100644 --- a/circuits/src/stark/recursive_verifier.rs +++ b/circuits/src/stark/recursive_verifier.rs @@ -608,6 +608,7 @@ where } #[cfg(test)] +#[allow(unused_imports)] mod tests { use std::panic; use std::panic::AssertUnwindSafe; @@ -639,6 +640,9 @@ mod tests { #[test] #[ignore] fn recursive_verify_mozak_starks() -> Result<()> { + #[cfg(not(feature = "cuda"))] + { + type S = MozakStark; let stark = S::default(); let mut config = StarkConfig::standard_fast_config(); config.fri_config.cap_height = 1; @@ -678,7 +682,9 @@ mod tests { ); let recursive_proof = mozak_stark_circuit.prove(&mozak_proof)?; - mozak_stark_circuit.circuit.verify(recursive_proof) + mozak_stark_circuit.circuit.verify(recursive_proof)?; + } + Ok(()) } #[test] @@ -708,6 +714,7 @@ mod tests { &stark_config0, public_inputs, &mut TimingTree::default(), + &mut None, )?; let (program1, record1) = execute_code(vec![inst; 128], &[], &[(6, 100), (7, 200)]); @@ -722,6 +729,7 @@ mod tests { &stark_config1, public_inputs, &mut TimingTree::default(), + &mut None, )?; // The degree bits should be different for the two proofs. diff --git a/circuits/src/stark/serde.rs b/circuits/src/stark/serde.rs index 4adbba7ea..8ab17489c 100644 --- a/circuits/src/stark/serde.rs +++ b/circuits/src/stark/serde.rs @@ -32,6 +32,7 @@ impl, C: GenericConfig, const D: usize> A } #[cfg(test)] +#[allow(unused_imports)] mod tests { use mozak_runner::util::execute_code; use plonky2::util::timing::TimingTree; @@ -45,13 +46,15 @@ mod tests { #[test] fn test_serialization_deserialization() { + + #[cfg(not(feature = "cuda"))] + { let (program, record) = execute_code([], &[], &[]); let stark = S::default(); let config = fast_test_config(); let public_inputs = PublicInputs { entry_point: from_u32(program.entry_point), }; - let all_proof = prove::( &program, &record, @@ -68,5 +71,6 @@ mod tests { AllProof::::deserialize_proof_from_flexbuffer(s.view()) .expect("deserialization failed"); verify_proof(&stark, all_proof_deserialized, &config).unwrap(); + } } } diff --git a/circuits/src/stark/verifier.rs b/circuits/src/stark/verifier.rs index d190e4042..a6ebba25a 100644 --- a/circuits/src/stark/verifier.rs +++ b/circuits/src/stark/verifier.rs @@ -76,6 +76,7 @@ where public_inputs[kind], &ctl_vars_per_table[kind], config, + kind, )?; }); verify_cross_table_lookups::( @@ -98,6 +99,7 @@ pub(crate) fn verify_stark_proof_with_challenges< public_inputs: &[F], ctl_vars: &[CtlCheckVars], config: &StarkConfig, + kind: TableKind, ) -> Result<()> where { @@ -157,9 +159,12 @@ where .chunks(stark.quotient_degree_factor()) .enumerate() { + let van_poly_zeta = vanishing_polys_zeta[i]; + let opening_poly = z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg); + log::info!("evaluation {:?}, quotient {:?}", van_poly_zeta, opening_poly); ensure!( - vanishing_polys_zeta[i] == z_h_zeta * reduce_with_powers(chunk, zeta_pow_deg), - "Mismatch between evaluation and opening of quotient polynomial" + van_poly_zeta == opening_poly, + "Mismatch between evaluation and opening of quotient polynomial {} {:?} {:?} {:?}", i, kind, van_poly_zeta, opening_poly ); } diff --git a/circuits/src/test_utils.rs b/circuits/src/test_utils.rs index f71e8b2d1..c2d87aa34 100644 --- a/circuits/src/test_utils.rs +++ b/circuits/src/test_utils.rs @@ -11,6 +11,7 @@ use mozak_system::system::reg_abi::{REG_A0, REG_A1, REG_A2, REG_A3}; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; use plonky2::fri::FriConfig; +use plonky2::fri::oracle::CudaInvContext; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::poseidon2::Poseidon2Hash; use plonky2::plonk::config::{GenericConfig, Hasher, Poseidon2GoldilocksConfig}; @@ -20,6 +21,10 @@ use starky::config::StarkConfig; use starky::prover::prove as prove_table; use starky::stark::Stark; use starky::verifier::verify_stark_proof; +use plonky2::util::log2_strict; +use plonky2::field::fft::fft_root_table; +use plonky2::fri::oracle::CudaInnerContext; + use crate::bitshift::stark::BitshiftStark; use crate::cpu::stark::CpuStark; @@ -52,6 +57,8 @@ use crate::stark::utils::{trace_rows_to_poly_values, trace_to_poly_values}; use crate::stark::verifier::verify_proof; use crate::utils::from_u32; use crate::xor::stark::XorStark; +use rustacuda::prelude::*; +use rustacuda::memory::DeviceBuffer; pub type S = MozakStark; pub const D: usize = 2; @@ -367,6 +374,52 @@ impl ProveAndVerify for MozakStark { } } +#[cfg(feature = "cuda")] +pub fn cuda_ctx() -> CudaInvContext { + + let ctx; + rustacuda::init(CudaFlags::empty()).unwrap(); + let device_index = 0; + let device = rustacuda::prelude::Device::get_device(device_index).unwrap(); + let _ctx = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device).unwrap(); + let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); + let stream2 = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); + + let values_num_per_poly = 1<<20; + let _blinding = false; + const _SALT_SIZE: usize = 4; + let rate_bits = 2; + let _cap_height = 4; + + let lg_n = log2_strict(values_num_per_poly); + let fft_root_table_deg = fft_root_table(1 << lg_n).concat(); + let fft_root_table_max = fft_root_table(1<<(lg_n + rate_bits)).concat(); + + + + let root_table_device = { + let root_table_device = DeviceBuffer::from_slice(&fft_root_table_deg).unwrap(); + root_table_device + }; + let root_table_device2 = { + let root_table_device = DeviceBuffer::from_slice(&fft_root_table_max).unwrap(); + root_table_device + }; + let shift_powers = F::coset_shift().powers().take(1<<(lg_n)).collect::>(); + let shift_powers_device = { + let shift_powers_device = DeviceBuffer::from_slice(&shift_powers).unwrap(); + shift_powers_device + }; + + ctx = plonky2::fri::oracle::CudaInvContext{ + inner: CudaInnerContext{stream, stream2,}, + root_table_device, + root_table_device2, + shift_powers_device, + ctx: _ctx, + }; + return ctx; +} pub fn prove_and_verify_mozak_stark( program: &Program, record: &ExecutionRecord, @@ -376,7 +429,28 @@ pub fn prove_and_verify_mozak_stark( let public_inputs = PublicInputs { entry_point: from_u32(program.entry_point), }; + + #[cfg(feature = "cuda")] + { + use plonky2::plonk::config::PoseidonGoldilocksConfig; +pub type C = PoseidonGoldilocksConfig; + let mut ctx = cuda_ctx(); + let all_proof = prove::( + program, + record, + &stark, + config, + public_inputs, + &mut TimingTree::default(), + &mut Some(&mut ctx), + )?; + verify_proof(&stark, all_proof, config) + // Ok(()) + } + + #[cfg(not(feature = "cuda"))] + { let all_proof = prove::( program, record, @@ -384,8 +458,10 @@ pub fn prove_and_verify_mozak_stark( config, public_inputs, &mut TimingTree::default(), + &mut None, )?; verify_proof(&stark, all_proof, config) + } } /// Interpret a u64 as a field element and try to invert it. diff --git a/cli/Cargo.toml b/cli/Cargo.toml index aa37f5e2a..5d805c7d4 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -37,3 +37,6 @@ tempfile = "3" mozak-circuits = { path = "../circuits", features = ["test"] } mozak-runner = { path = "../runner", features = ["test"] } proptest = "1.4" + +[features] +cuda = ["mozak-circuits/cuda"] diff --git a/cli/src/main.rs b/cli/src/main.rs index 36b79284b..308599ca2 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -23,6 +23,9 @@ use mozak_circuits::stark::recursive_verifier::{ use mozak_circuits::stark::utils::trace_rows_to_poly_values; use mozak_circuits::stark::verifier::verify_proof; use mozak_circuits::test_utils::{prove_and_verify_mozak_stark, C, D, F, S}; +#[cfg(feature = "cuda")] +use mozak_circuits::test_utils::cuda_ctx; + use mozak_cli::cli_benches::benches::BenchArgs; use mozak_runner::elf::Program; use mozak_runner::state::State; @@ -186,6 +189,10 @@ fn main() -> Result<()> { let public_inputs = PublicInputs { entry_point: F::from_canonical_u32(program.entry_point), }; + + #[cfg(feature = "cuda")] + { + let mut ctx =cuda_ctx(); let all_proof = prove::( &program, &record, @@ -193,10 +200,10 @@ fn main() -> Result<()> { &config, public_inputs, &mut TimingTree::default(), + &mut Some(&mut ctx), )?; let s = all_proof.serialize_proof_to_flexbuffer()?; proof.write_all(s.view())?; - // Generate recursive proof if let Some(mut recursive_proof_output) = recursive_proof { let degree_bits = all_proof.degree_bits(&config); @@ -231,6 +238,45 @@ fn main() -> Result<()> { let bytes = final_circuit.circuit.verifier_only.to_bytes().unwrap(); vk_output.write_all(&bytes)?; } + } + + #[cfg(not(feature = "cuda"))] + { + let all_proof = prove::( + &program, + &record, + &stark, + &config, + public_inputs, + &mut TimingTree::default(), + )?; + let s = all_proof.serialize_proof_to_flexbuffer()?; + proof.write_all(s.view())?; + // Generate recursive proof + if let Some(mut recursive_proof_output) = recursive_proof { + let circuit_config = CircuitConfig::standard_recursion_config(); + let degree_bits = all_proof.degree_bits(&config); + let recursive_circuit = recursive_mozak_stark_circuit::( + &stark, + °ree_bits, + &circuit_config, + &config, + ); + + let recursive_all_proof = recursive_circuit.prove(&all_proof)?; + let s = recursive_all_proof.to_bytes(); + recursive_proof_output.write_all(&s)?; + + // Generate the degree bits file + let mut degree_bits_output_path = recursive_proof_output.path().clone(); + degree_bits_output_path.set_extension("db"); + let mut degree_bits_output = degree_bits_output_path.create()?; + + let serialized = serde_json::to_string(°ree_bits)?; + degree_bits_output.write_all(serialized.as_bytes())?; + } + } + debug!("proof generated successfully!"); } From dfcb95970b7378739182abb0c92b187094c74a0d Mon Sep 17 00:00:00 2001 From: codeblooded1729 Date: Mon, 4 Mar 2024 13:57:34 +0530 Subject: [PATCH 2/3] add docs --- circuits/src/stark/prover.rs | 144 ++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 61 deletions(-) diff --git a/circuits/src/stark/prover.rs b/circuits/src/stark/prover.rs index 69d9a34b6..e0735d3f3 100644 --- a/circuits/src/stark/prover.rs +++ b/circuits/src/stark/prover.rs @@ -11,7 +11,7 @@ use plonky2::field::extension::Extendable; use plonky2::field::packable::Packable; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; -use plonky2::fri::oracle::PolynomialBatch; +use plonky2::fri::oracle::{CudaInvContext, PolynomialBatch}; use plonky2::hash::hash_types::RichField; use plonky2::iop::challenger::Challenger; use plonky2::plonk::config::GenericConfig; @@ -31,7 +31,6 @@ use crate::stark::mozak_stark::{all_starks, PublicInputs}; use crate::stark::permutation::challenge::GrandProductChallengeTrait; use crate::stark::poly::compute_quotient_polys; use crate::stark::proof::StarkProofWithMetadata; -use plonky2::fri::oracle::CudaInvContext; /// Prove the execution of a given [Program] /// @@ -49,7 +48,7 @@ pub fn prove( config: &StarkConfig, public_inputs: PublicInputs, timing: &mut TimingTree, - ctx: &mut Option<&mut CudaInvContext> + ctx: &mut Option<&mut CudaInvContext>, ) -> Result> where F: RichField + Extendable, @@ -79,7 +78,7 @@ pub fn prove_with_traces( public_inputs: PublicInputs, traces_poly_values: &TableKindArray>>, timing: &mut TimingTree, - ctx: &mut Option<&mut CudaInvContext> + ctx: &mut Option<&mut CudaInvContext>, ) -> Result> where F: RichField + Extendable, @@ -90,54 +89,55 @@ where #[cfg(feature = "cuda")] { - trace_commitments = timed!( - timing, - "Compute trace commitments for each table", - traces_poly_values - .clone() - .with_kind() - .map(|(trace, table)| { - timed!( - timing, - &format!("compute trace commitment for {table:?}"), - PolynomialBatch::::from_values_cuda( - trace.clone(), - rate_bits, - false, - cap_height, + trace_commitments = timed!( + timing, + "Compute trace commitments for each table", + traces_poly_values + .clone() + .with_kind() + .map(|(trace, table)| { + timed!( timing, - trace.len(), - trace.first().expect("Not a single polynomial").len(), - ctx.as_mut().unwrap(), + &format!("compute trace commitment for {table:?}"), + // creates merkle tree out of trace polynomials over gpu + PolynomialBatch::::from_values_cuda( + trace.clone(), + rate_bits, + false, + cap_height, + timing, + trace.len(), + trace.first().expect("Not a single polynomial").len(), + ctx.as_mut().unwrap(), + ) ) - ) - }) - ); + }) + ); } #[cfg(not(feature = "cuda"))] { - trace_commitments = timed!( - timing, - "Compute trace commitments for each table", - traces_poly_values - .clone() - .with_kind() - .map(|(trace, table)| { - timed!( - timing, - &format!("compute trace commitment for {table:?}"), - PolynomialBatch::::from_values( - trace.clone(), - rate_bits, - false, - cap_height, + trace_commitments = timed!( + timing, + "Compute trace commitments for each table", + traces_poly_values + .clone() + .with_kind() + .map(|(trace, table)| { + timed!( timing, - None, + &format!("compute trace commitment for {table:?}"), + PolynomialBatch::::from_values( + trace.clone(), + rate_bits, + false, + cap_height, + timing, + None, + ) ) - ) - }) - ); + }) + ); } // log::info!("trace_commitments {:?}", trace_commitments); @@ -388,7 +388,7 @@ pub fn prove_with_commitments( ctl_data_per_table: &TableKindArray>, challenger: &mut Challenger, timing: &mut TimingTree, - ctx: &mut Option<&mut CudaInvContext> + ctx: &mut Option<&mut CudaInvContext>, ) -> Result>> where F: RichField + Extendable, @@ -538,37 +538,45 @@ mod tests { }, ]); } + + /// Test for ensuring Polynomial batch computed over gpu is same as + /// that computed over cpu #[test] #[cfg(feature = "cuda")] fn test_cuda_poly_batch() { + use plonky2::field::polynomial::PolynomialValues; + use plonky2::field::types::Sample; use plonky2::fri::oracle::PolynomialBatch; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::util::timing::TimingTree; - use plonky2::plonk::config::GenericConfig; - use plonky2::plonk::config::PoseidonGoldilocksConfig; - use plonky2::field::types::Sample; - use plonky2::field::polynomial::PolynomialValues; - - -const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - // use plonky2::field::fft::fft_root_table; + + const D: usize = 2; + // the cuda code only supports Poseidon for now (Not Poseidon2!) + type C = PoseidonGoldilocksConfig; + type F = >::F; + + // number of rows of trace table let values_num_per_poly = 1 << 6; + // number of columns of trace table let poly_num = 8; + // generate random trace let mut polys = vec![]; for _i in 0..poly_num { let poly: Vec = (0..values_num_per_poly).map(|_| F::rand()).collect(); let poly_as_value = PolynomialValues::new(poly); polys.push(poly_as_value); } - let rate_bits = 3; + // rate_bits = log2ceil(constraint_degree) + let rate_bits = 2; let cap_height = 4; let len_cap = 1 << cap_height; + // flattened len of all lde polynomials let _all_len = poly_num * values_num_per_poly * (1 << rate_bits); let num_digests = 2 * (values_num_per_poly * (1 << rate_bits) - len_cap); let _num_digests_and_caps = num_digests + len_cap; let blinding = false; let timing = &mut TimingTree::default(); + // merkle tree over over cpu let batch: PolynomialBatch = PolynomialBatch::from_values( polys.clone(), rate_bits, @@ -578,6 +586,7 @@ const D: usize = 2; None, ); let ctx = &mut crate::test_utils::cuda_ctx(); + // merkle tree over gpu let cuda_batch: PolynomialBatch = PolynomialBatch::from_values_cuda( polys, rate_bits, @@ -588,13 +597,26 @@ const D: usize = 2; values_num_per_poly, ctx, ); + + // check that polynomials were computed in coefficient form + // are same for cpu and gpu. assert_eq!(batch.polynomials, cuda_batch.polynomials); - let leaves = batch.merkle_tree.leaves.into_iter().flatten().collect::>(); + let leaves = batch + .merkle_tree + .leaves + .into_iter() + .flatten() + .collect::>(); + // check that merkle tree computed over cpu is same as gpu assert_eq!(leaves, *cuda_batch.merkle_tree.my_leaves); assert_eq!(batch.merkle_tree.cap, cuda_batch.merkle_tree.cap); - assert_eq!(batch.merkle_tree.digests.len(), cuda_batch.merkle_tree.my_digests.len()); - assert_eq!(batch.merkle_tree.digests, *cuda_batch.merkle_tree.my_digests); + assert_eq!( + batch.merkle_tree.digests.len(), + cuda_batch.merkle_tree.my_digests[..num_digests].len() + ); + assert_eq!( + batch.merkle_tree.digests, + cuda_batch.merkle_tree.my_digests[..num_digests] + ); } - - } From b1a896f05516fdf6474014a4ddd8b91d05be7271 Mon Sep 17 00:00:00 2001 From: codeblooded1729 Date: Mon, 4 Mar 2024 14:03:00 +0530 Subject: [PATCH 3/3] add more docs to ctx generation --- circuits/src/test_utils.rs | 150 +++++++++++++++++++------------------ 1 file changed, 76 insertions(+), 74 deletions(-) diff --git a/circuits/src/test_utils.rs b/circuits/src/test_utils.rs index c2d87aa34..6cdcdd713 100644 --- a/circuits/src/test_utils.rs +++ b/circuits/src/test_utils.rs @@ -8,23 +8,22 @@ use mozak_runner::util::execute_code; use mozak_runner::vm::ExecutionRecord; use mozak_system::system::ecall; use mozak_system::system::reg_abi::{REG_A0, REG_A1, REG_A2, REG_A3}; +use plonky2::field::fft::fft_root_table; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::Field; +use plonky2::fri::oracle::{CudaInnerContext, CudaInvContext}; use plonky2::fri::FriConfig; -use plonky2::fri::oracle::CudaInvContext; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::poseidon2::Poseidon2Hash; use plonky2::plonk::config::{GenericConfig, Hasher, Poseidon2GoldilocksConfig}; -use plonky2::util::log2_ceil; use plonky2::util::timing::TimingTree; +use plonky2::util::{log2_ceil, log2_strict}; +use rustacuda::memory::DeviceBuffer; +use rustacuda::prelude::*; use starky::config::StarkConfig; use starky::prover::prove as prove_table; use starky::stark::Stark; use starky::verifier::verify_stark_proof; -use plonky2::util::log2_strict; -use plonky2::field::fft::fft_root_table; -use plonky2::fri::oracle::CudaInnerContext; - use crate::bitshift::stark::BitshiftStark; use crate::cpu::stark::CpuStark; @@ -57,8 +56,6 @@ use crate::stark::utils::{trace_rows_to_poly_values, trace_to_poly_values}; use crate::stark::verifier::verify_proof; use crate::utils::from_u32; use crate::xor::stark::XorStark; -use rustacuda::prelude::*; -use rustacuda::memory::DeviceBuffer; pub type S = MozakStark; pub const D: usize = 2; @@ -376,49 +373,55 @@ impl ProveAndVerify for MozakStark { #[cfg(feature = "cuda")] pub fn cuda_ctx() -> CudaInvContext { + let ctx; + rustacuda::init(CudaFlags::empty()).unwrap(); + let device_index = 0; + let device = rustacuda::prelude::Device::get_device(device_index).unwrap(); + // TODO(Kapil): is this required? + let _ctx = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device) + .unwrap(); + let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); + // TODO(Kapil): figure out if this is necessary for now? + let stream2 = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); + // max size of number of rows + let values_num_per_poly = 1 << 20; + let _blinding = false; + const _SALT_SIZE: usize = 4; + let rate_bits = 2; + let _cap_height = 4; + + let lg_n = log2_strict(values_num_per_poly); + let fft_root_table_deg = fft_root_table(1 << lg_n).concat(); + let fft_root_table_max = fft_root_table(1 << (lg_n + rate_bits)).concat(); + + // would be used for fft over trace domain + let root_table_device = { + let root_table_device = DeviceBuffer::from_slice(&fft_root_table_deg).unwrap(); + root_table_device + }; + // used for fft over lde domain + let root_table_device2 = { + let root_table_device = DeviceBuffer::from_slice(&fft_root_table_max).unwrap(); + root_table_device + }; + let shift_powers = F::coset_shift() + .powers() + .take(1 << (lg_n)) + .collect::>(); + // used for coset fft over lde domain + let shift_powers_device = { + let shift_powers_device = DeviceBuffer::from_slice(&shift_powers).unwrap(); + shift_powers_device + }; - let ctx; - rustacuda::init(CudaFlags::empty()).unwrap(); - let device_index = 0; - let device = rustacuda::prelude::Device::get_device(device_index).unwrap(); - let _ctx = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device).unwrap(); - let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); - let stream2 = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap(); - - let values_num_per_poly = 1<<20; - let _blinding = false; - const _SALT_SIZE: usize = 4; - let rate_bits = 2; - let _cap_height = 4; - - let lg_n = log2_strict(values_num_per_poly); - let fft_root_table_deg = fft_root_table(1 << lg_n).concat(); - let fft_root_table_max = fft_root_table(1<<(lg_n + rate_bits)).concat(); - - - - let root_table_device = { - let root_table_device = DeviceBuffer::from_slice(&fft_root_table_deg).unwrap(); - root_table_device - }; - let root_table_device2 = { - let root_table_device = DeviceBuffer::from_slice(&fft_root_table_max).unwrap(); - root_table_device - }; - let shift_powers = F::coset_shift().powers().take(1<<(lg_n)).collect::>(); - let shift_powers_device = { - let shift_powers_device = DeviceBuffer::from_slice(&shift_powers).unwrap(); - shift_powers_device - }; - - ctx = plonky2::fri::oracle::CudaInvContext{ - inner: CudaInnerContext{stream, stream2,}, - root_table_device, - root_table_device2, - shift_powers_device, - ctx: _ctx, - }; - return ctx; + ctx = plonky2::fri::oracle::CudaInvContext { + inner: CudaInnerContext { stream, stream2 }, + root_table_device, + root_table_device2, + shift_powers_device, + ctx: _ctx, + }; + return ctx; } pub fn prove_and_verify_mozak_stark( program: &Program, @@ -429,38 +432,37 @@ pub fn prove_and_verify_mozak_stark( let public_inputs = PublicInputs { entry_point: from_u32(program.entry_point), }; - + #[cfg(feature = "cuda")] { use plonky2::plonk::config::PoseidonGoldilocksConfig; -pub type C = PoseidonGoldilocksConfig; + pub type C = PoseidonGoldilocksConfig; let mut ctx = cuda_ctx(); - let all_proof = prove::( - program, - record, - &stark, - config, - public_inputs, - &mut TimingTree::default(), - &mut Some(&mut ctx), - )?; - verify_proof(&stark, all_proof, config) - // Ok(()) - + let all_proof = prove::( + program, + record, + &stark, + config, + public_inputs, + &mut TimingTree::default(), + &mut Some(&mut ctx), + )?; + verify_proof(&stark, all_proof, config) + // Ok(()) } #[cfg(not(feature = "cuda"))] { - let all_proof = prove::( - program, - record, - &stark, - config, - public_inputs, - &mut TimingTree::default(), - &mut None, - )?; - verify_proof(&stark, all_proof, config) + let all_proof = prove::( + program, + record, + &stark, + config, + public_inputs, + &mut TimingTree::default(), + &mut None, + )?; + verify_proof(&stark, all_proof, config) } }