Skip to content

Commit 0637395

Browse files
authored
fix(audit): refactor pad_variable_length for sha256 + sha512 (#342)
1 parent 3b8ba82 commit 0637395

File tree

5 files changed

+184
-145
lines changed

5 files changed

+184
-145
lines changed

plonky2x/core/src/frontend/hash/sha/sha256/curta.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ impl<L: PlonkParameters<D>, const D: usize> Hash<L, D, 64, false, 8> for SHA256
5757
input: &[ByteVariable],
5858
length: U32Variable,
5959
) -> Vec<Self::IntVariable> {
60-
let padded_bytes = builder.pad_message_sha256_variable(input, length);
60+
let padded_bytes = builder.pad_sha256_variable_length(input, length);
6161

6262
padded_bytes
6363
.chunks_exact(4)
@@ -148,6 +148,12 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
148148
0,
149149
"input length should be a multiple of 64"
150150
);
151+
152+
// Check that length <= input.len(). This is needed to ensure that users cannot prove the
153+
// hash of a longer message than they supplied.
154+
let supplied_input_length = self.constant::<U32Variable>(input.len() as u32);
155+
self.lte(length, supplied_input_length);
156+
151157
let last_chunk = self.compute_sha256_last_chunk(length);
152158
if self.sha256_accelerator.is_none() {
153159
self.sha256_accelerator = Some(SHA256Accelerator {

plonky2x/core/src/frontend/hash/sha/sha256/mod.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,13 @@ mod tests {
161161
use super::*;
162162
use crate::prelude::{ByteVariable, CircuitBuilder, DefaultParameters, U32Variable};
163163
use crate::utils::hash::sha256;
164-
use crate::utils::setup_logger;
165164

166165
type L = DefaultParameters;
167166
const D: usize = 2;
168167

169168
#[test]
170169
#[cfg_attr(feature = "ci", ignore)]
171170
fn test_sha256_padding() {
172-
setup_logger();
173-
174171
let mut builder = CircuitBuilder::<L, D>::new();
175172

176173
let max_len = 1024;
@@ -209,8 +206,6 @@ mod tests {
209206
#[test]
210207
#[cfg_attr(feature = "ci", ignore)]
211208
fn test_sha256_variable_padding() {
212-
setup_logger();
213-
214209
let mut builder = CircuitBuilder::<L, D>::new();
215210

216211
let max_number_of_chunks = 5;
@@ -239,7 +234,7 @@ mod tests {
239234
.map(|b| builder.constant::<ByteVariable>(*b))
240235
.collect::<Vec<_>>();
241236

242-
let padding = builder.pad_message_sha256_variable(&message, length);
237+
let padding = builder.pad_sha256_variable_length(&message, length);
243238

244239
for (value, expected) in padding.iter().zip(expected_padding.iter()) {
245240
builder.assert_is_equal(*value, *expected);

plonky2x/core/src/frontend/hash/sha/sha256/pad.rs

+60-34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
use array_macro::array;
2+
use plonky2::util::ceil_div_usize;
23

34
use crate::prelude::*;
45

6+
pub const SHA256_CHUNK_SIZE_BYTES: usize = 64;
7+
pub const SHA256_INPUT_LENGTH_BYTE_SIZE: usize = 8;
8+
9+
pub const SHA256_CHUNK_SIZE_BITS: usize = SHA256_CHUNK_SIZE_BYTES * 8;
10+
pub const SHA256_INPUT_LENGTH_BIT_SIZE: usize = SHA256_INPUT_LENGTH_BYTE_SIZE * 8;
11+
512
impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
613
/// Pad the given input according to the SHA-256 spec.
714
/// The last chunk (each chunk is 64 bytes = 512 bits) gets padded.
@@ -13,7 +20,8 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
1320
bits.push(self.api._true());
1421

1522
let l = bits.len() - 1;
16-
let k = 512 - (l + 1 + 64) % 512;
23+
let k = SHA256_CHUNK_SIZE_BITS
24+
- (l + 1 + SHA256_INPUT_LENGTH_BIT_SIZE) % SHA256_CHUNK_SIZE_BITS;
1725
for _ in 0..k {
1826
bits.push(self.api._false());
1927
}
@@ -51,53 +59,66 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
5159
}
5260

5361
/// Pad the given variable length input according to the SHA-256 spec.
54-
///
55-
/// It is assumed that `input` has length MAX_NUM_CHUNKS * 64.
56-
/// The true number of non-zero bytes in `input` is given by input_byte_length.
57-
pub(crate) fn pad_message_sha256_variable(
62+
/// input_byte_length gives the real length of the input in bytes.
63+
pub(crate) fn pad_sha256_variable_length(
5864
&mut self,
5965
input: &[ByteVariable],
6066
input_byte_length: U32Variable,
6167
) -> Vec<ByteVariable> {
62-
let max_number_of_chunks = input.len() / 64;
63-
assert_eq!(
64-
max_number_of_chunks * 64,
65-
input.len(),
66-
"input length must be a multiple of 64 bytes"
67-
);
68+
let true_t = self._true();
69+
let false_t = self._false();
70+
6871
let last_chunk = self.compute_sha256_last_chunk(input_byte_length);
6972

73+
// Calculate the number of chunks needed to store the input. 9 bytes are added by the
74+
// padding and LE length representation.
75+
let max_num_chunks = ceil_div_usize(
76+
input.len() + SHA256_INPUT_LENGTH_BYTE_SIZE + 1,
77+
SHA256_CHUNK_SIZE_BYTES,
78+
);
79+
80+
// Extend input to size max_num_chunks * 64 before padding.
81+
let mut padded_input = input.to_vec();
82+
padded_input.resize(max_num_chunks * SHA256_CHUNK_SIZE_BYTES, self.zero());
83+
7084
// Compute the length bytes (big-endian representation of the length in bits).
7185
let zero_byte = self.constant::<ByteVariable>(0x00);
72-
let mut length_bytes = vec![zero_byte; 4];
7386

7487
let bits_per_byte = self.constant::<U32Variable>(8);
7588
let input_bit_length = self.mul(input_byte_length, bits_per_byte);
7689

77-
let mut length_bits = self.to_le_bits(input_bit_length);
90+
// Get the length bits in LE order, padded to 64 bits.
91+
let mut length_bits = self
92+
.api
93+
.split_le(input_bit_length.variable.0, SHA256_INPUT_LENGTH_BIT_SIZE);
94+
// Convert length to BE bits
7895
length_bits.reverse();
7996

80-
// Prepend 4 zero bytes to length_bytes as abi.encodePacked(U32Variable) is 4 bytes.
81-
length_bytes.extend_from_slice(
82-
&length_bits
83-
.chunks(8)
84-
.map(|chunk| {
85-
let bits = array![x => chunk[x]; 8];
86-
ByteVariable(bits)
87-
})
88-
.collect::<Vec<_>>(),
89-
);
97+
let length_bytes = &length_bits
98+
.chunks(8)
99+
.map(|chunk| {
100+
let bits = array![x => BoolVariable::from_targets(&[chunk[x].target]); 8];
101+
ByteVariable(bits)
102+
})
103+
.collect::<Vec<_>>();
90104

91105
let mut padded_bytes = Vec::new();
92106

93-
let mut message_byte_selector = self.constant::<BoolVariable>(true);
94-
for i in 0..max_number_of_chunks {
95-
let chunk_offset = 64 * i;
107+
// Set to true if the last chunk has been reached. This is used to verify that
108+
// input_byte_length is <= input.len().
109+
let mut reached_last_chunk = false_t;
110+
111+
let mut message_byte_selector = true_t;
112+
for i in 0..max_num_chunks {
113+
let chunk_offset = SHA256_CHUNK_SIZE_BYTES * i;
96114
let curr_chunk = self.constant::<U32Variable>(i as u32);
97115

116+
// Check if this is the chunk where length should be added.
98117
let is_last_chunk = self.is_equal(curr_chunk, last_chunk);
118+
reached_last_chunk = self.or(reached_last_chunk, is_last_chunk);
99119

100-
for j in 0..64 {
120+
for j in 0..SHA256_CHUNK_SIZE_BYTES {
121+
// First 64 - 8 bytes are either message | padding | nil bytes.
101122
let idx = chunk_offset + j;
102123
let idx_t = self.constant::<U32Variable>(idx as u32);
103124
let is_last_msg_byte = self.is_equal(idx_t, input_byte_length);
@@ -112,21 +133,26 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
112133
let padding_start_byte = self.constant::<ByteVariable>(0x80);
113134

114135
// If message_byte_selector is true, select the message byte.
115-
let mut byte = self.select(message_byte_selector, input[idx], zero_byte);
136+
let mut byte = self.select(message_byte_selector, padded_input[idx], zero_byte);
116137
// If idx == length_bytes, select the padding start byte.
117138
byte = self.select(is_last_msg_byte, padding_start_byte, byte);
118-
if j >= 64 - 8 {
119-
// If in last chunk, select the length byte.
120-
byte = self.select(is_last_chunk, length_bytes[j % 8], byte);
139+
140+
if j >= SHA256_CHUNK_SIZE_BYTES - SHA256_INPUT_LENGTH_BYTE_SIZE {
141+
// If in last chunk, this is a length byte.
142+
byte = self.select(
143+
is_last_chunk,
144+
length_bytes[j % SHA256_INPUT_LENGTH_BYTE_SIZE],
145+
byte,
146+
);
121147
}
122148

123149
padded_bytes.push(byte);
124150
}
125151
}
152+
// These checks verify input_byte_length <= input.len().
153+
self.is_equal(message_byte_selector, false_t);
154+
self.is_equal(reached_last_chunk, true_t);
126155

127-
// self.watch_slice(&padded_bytes, "padded bytes");
128-
129-
assert_eq!(padded_bytes.len(), max_number_of_chunks * 64);
130156
padded_bytes
131157
}
132158
}

plonky2x/core/src/frontend/hash/sha/sha512/curta.rs

+22-6
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
152152
input: &[ByteVariable],
153153
length: U32Variable,
154154
) -> BytesVariable<64> {
155+
// Check that length <= input.len(). This is needed to ensure that users cannot prove the
156+
// hash of a longer message than they supplied.
157+
let supplied_input_length = self.constant::<U32Variable>(input.len() as u32);
158+
self.lte(length, supplied_input_length);
159+
155160
let last_chunk = self.compute_sha512_last_chunk(length);
156161

157162
if self.sha512_accelerator.is_none() {
@@ -178,14 +183,14 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
178183

179184
#[cfg(test)]
180185
mod tests {
186+
use std::env;
187+
181188
use rand::{thread_rng, Rng};
182189

183190
use crate::prelude::*;
184191
use crate::utils::hash::sha512;
185-
use crate::utils::setup_logger;
186192

187193
fn test_sha512_fixed(msg: &[u8], expected_digest: [u8; 64]) {
188-
setup_logger();
189194
let mut builder = DefaultBuilder::new();
190195
let message = msg
191196
.iter()
@@ -203,7 +208,6 @@ mod tests {
203208
}
204209

205210
fn test_sha512_variable_length(message: &[u8], input_length: u32, expected_digest: [u8; 64]) {
206-
setup_logger();
207211
let mut builder = DefaultBuilder::new();
208212

209213
let input_length = builder.constant::<U32Variable>(input_length);
@@ -268,9 +272,14 @@ mod tests {
268272
test_sha512_variable_length(&msg, 0, expected_digest);
269273
}
270274

275+
// FAILED
271276
#[test]
272277
#[cfg_attr(feature = "ci", ignore)]
273278
fn test_sha512_curta_variable_large_message() {
279+
env::set_var("RUST_LOG", "debug");
280+
env_logger::try_init().unwrap_or_default();
281+
dotenv::dotenv().ok();
282+
274283
let mut msg : Vec<u8> = bytes!("35c323757c20640a294345c89c0bfcebe3d554fdb0c7b7a0bdb72222c531b1ecf7ec1c43f4de9d49556de87b86b26a98942cb078486fdb44de38b80864c3973153756363696e6374204c616273");
275284
let len = msg.len() as u32;
276285
msg.resize(256, 1);
@@ -279,9 +288,14 @@ mod tests {
279288
test_sha512_variable_length(&msg, len, expected_digest);
280289
}
281290

291+
// FAILED
282292
#[test]
283293
#[cfg_attr(feature = "ci", ignore)]
284294
fn test_sha512_curta_variable_short_message_same_slice() {
295+
env::set_var("RUST_LOG", "debug");
296+
env_logger::try_init().unwrap_or_default();
297+
dotenv::dotenv().ok();
298+
285299
let mut msg: Vec<u8> = b"plonky2".to_vec();
286300
let len = msg.len() as u32;
287301
msg.resize(128, 1);
@@ -290,9 +304,14 @@ mod tests {
290304
test_sha512_variable_length(&msg, len, expected_digest);
291305
}
292306

307+
// FAILED
293308
#[test]
294309
#[cfg_attr(feature = "ci", ignore)]
295310
fn test_sha512_curta_variable_short_message_different_slice() {
311+
env::set_var("RUST_LOG", "debug");
312+
env_logger::try_init().unwrap_or_default();
313+
dotenv::dotenv().ok();
314+
296315
let mut msg: Vec<u8> = b"plonky2".to_vec();
297316
let len = msg.len() as u32;
298317
msg.resize(512, 1);
@@ -304,7 +323,6 @@ mod tests {
304323
#[test]
305324
#[cfg_attr(feature = "ci", ignore)]
306325
fn test_sha512_fixed_length() {
307-
setup_logger();
308326
let mut builder = DefaultBuilder::new();
309327

310328
let max_len = 300;
@@ -331,7 +349,6 @@ mod tests {
331349
#[test]
332350
#[cfg_attr(feature = "ci", ignore)]
333351
fn test_sha512_variable_length_random() {
334-
setup_logger();
335352
let mut builder = DefaultBuilder::new();
336353

337354
let max_number_of_chunks = 2;
@@ -367,7 +384,6 @@ mod tests {
367384
fn test_sha512_variable_length_max_size() {
368385
// This test checks that sha512_variable_pad works as intended, especially when the max
369386
// input length is (length % 128 > 128 - 17).
370-
setup_logger();
371387
let mut builder = DefaultBuilder::new();
372388

373389
let max_number_of_chunks = 1;

0 commit comments

Comments
 (0)