1
1
use array_macro:: array;
2
+ use plonky2:: util:: ceil_div_usize;
2
3
3
4
use crate :: prelude:: * ;
4
5
6
+ pub const SHA256_CHUNK_SIZE_BYTES : usize = 64 ;
7
+ pub const SHA256_INPUT_LENGTH_BYTE_SIZE : usize = 8 ;
8
+
9
+ pub const SHA256_CHUNK_SIZE_BITS : usize = SHA256_CHUNK_SIZE_BYTES * 8 ;
10
+ pub const SHA256_INPUT_LENGTH_BIT_SIZE : usize = SHA256_INPUT_LENGTH_BYTE_SIZE * 8 ;
11
+
5
12
impl < L : PlonkParameters < D > , const D : usize > CircuitBuilder < L , D > {
6
13
/// Pad the given input according to the SHA-256 spec.
7
14
/// The last chunk (each chunk is 64 bytes = 512 bits) gets padded.
@@ -13,7 +20,8 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
13
20
bits. push ( self . api . _true ( ) ) ;
14
21
15
22
let l = bits. len ( ) - 1 ;
16
- let k = 512 - ( l + 1 + 64 ) % 512 ;
23
+ let k = SHA256_CHUNK_SIZE_BITS
24
+ - ( l + 1 + SHA256_INPUT_LENGTH_BIT_SIZE ) % SHA256_CHUNK_SIZE_BITS ;
17
25
for _ in 0 ..k {
18
26
bits. push ( self . api . _false ( ) ) ;
19
27
}
@@ -51,53 +59,66 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
51
59
}
52
60
53
61
/// Pad the given variable length input according to the SHA-256 spec.
54
- ///
55
- /// It is assumed that `input` has length MAX_NUM_CHUNKS * 64.
56
- /// The true number of non-zero bytes in `input` is given by input_byte_length.
57
- pub ( crate ) fn pad_message_sha256_variable (
62
+ /// input_byte_length gives the real length of the input in bytes.
63
+ pub ( crate ) fn pad_sha256_variable_length (
58
64
& mut self ,
59
65
input : & [ ByteVariable ] ,
60
66
input_byte_length : U32Variable ,
61
67
) -> Vec < ByteVariable > {
62
- let max_number_of_chunks = input. len ( ) / 64 ;
63
- assert_eq ! (
64
- max_number_of_chunks * 64 ,
65
- input. len( ) ,
66
- "input length must be a multiple of 64 bytes"
67
- ) ;
68
+ let true_t = self . _true ( ) ;
69
+ let false_t = self . _false ( ) ;
70
+
68
71
let last_chunk = self . compute_sha256_last_chunk ( input_byte_length) ;
69
72
73
+ // Calculate the number of chunks needed to store the input. 9 bytes are added by the
74
+ // padding and LE length representation.
75
+ let max_num_chunks = ceil_div_usize (
76
+ input. len ( ) + SHA256_INPUT_LENGTH_BYTE_SIZE + 1 ,
77
+ SHA256_CHUNK_SIZE_BYTES ,
78
+ ) ;
79
+
80
+ // Extend input to size max_num_chunks * 64 before padding.
81
+ let mut padded_input = input. to_vec ( ) ;
82
+ padded_input. resize ( max_num_chunks * SHA256_CHUNK_SIZE_BYTES , self . zero ( ) ) ;
83
+
70
84
// Compute the length bytes (big-endian representation of the length in bits).
71
85
let zero_byte = self . constant :: < ByteVariable > ( 0x00 ) ;
72
- let mut length_bytes = vec ! [ zero_byte; 4 ] ;
73
86
74
87
let bits_per_byte = self . constant :: < U32Variable > ( 8 ) ;
75
88
let input_bit_length = self . mul ( input_byte_length, bits_per_byte) ;
76
89
77
- let mut length_bits = self . to_le_bits ( input_bit_length) ;
90
+ // Get the length bits in LE order, padded to 64 bits.
91
+ let mut length_bits = self
92
+ . api
93
+ . split_le ( input_bit_length. variable . 0 , SHA256_INPUT_LENGTH_BIT_SIZE ) ;
94
+ // Convert length to BE bits
78
95
length_bits. reverse ( ) ;
79
96
80
- // Prepend 4 zero bytes to length_bytes as abi.encodePacked(U32Variable) is 4 bytes.
81
- length_bytes. extend_from_slice (
82
- & length_bits
83
- . chunks ( 8 )
84
- . map ( |chunk| {
85
- let bits = array ! [ x => chunk[ x] ; 8 ] ;
86
- ByteVariable ( bits)
87
- } )
88
- . collect :: < Vec < _ > > ( ) ,
89
- ) ;
97
+ let length_bytes = & length_bits
98
+ . chunks ( 8 )
99
+ . map ( |chunk| {
100
+ let bits = array ! [ x => BoolVariable :: from_targets( & [ chunk[ x] . target] ) ; 8 ] ;
101
+ ByteVariable ( bits)
102
+ } )
103
+ . collect :: < Vec < _ > > ( ) ;
90
104
91
105
let mut padded_bytes = Vec :: new ( ) ;
92
106
93
- let mut message_byte_selector = self . constant :: < BoolVariable > ( true ) ;
94
- for i in 0 ..max_number_of_chunks {
95
- let chunk_offset = 64 * i;
107
+ // Set to true if the last chunk has been reached. This is used to verify that
108
+ // input_byte_length is <= input.len().
109
+ let mut reached_last_chunk = false_t;
110
+
111
+ let mut message_byte_selector = true_t;
112
+ for i in 0 ..max_num_chunks {
113
+ let chunk_offset = SHA256_CHUNK_SIZE_BYTES * i;
96
114
let curr_chunk = self . constant :: < U32Variable > ( i as u32 ) ;
97
115
116
+ // Check if this is the chunk where length should be added.
98
117
let is_last_chunk = self . is_equal ( curr_chunk, last_chunk) ;
118
+ reached_last_chunk = self . or ( reached_last_chunk, is_last_chunk) ;
99
119
100
- for j in 0 ..64 {
120
+ for j in 0 ..SHA256_CHUNK_SIZE_BYTES {
121
+ // First 64 - 8 bytes are either message | padding | nil bytes.
101
122
let idx = chunk_offset + j;
102
123
let idx_t = self . constant :: < U32Variable > ( idx as u32 ) ;
103
124
let is_last_msg_byte = self . is_equal ( idx_t, input_byte_length) ;
@@ -112,21 +133,26 @@ impl<L: PlonkParameters<D>, const D: usize> CircuitBuilder<L, D> {
112
133
let padding_start_byte = self . constant :: < ByteVariable > ( 0x80 ) ;
113
134
114
135
// If message_byte_selector is true, select the message byte.
115
- let mut byte = self . select ( message_byte_selector, input [ idx] , zero_byte) ;
136
+ let mut byte = self . select ( message_byte_selector, padded_input [ idx] , zero_byte) ;
116
137
// If idx == length_bytes, select the padding start byte.
117
138
byte = self . select ( is_last_msg_byte, padding_start_byte, byte) ;
118
- if j >= 64 - 8 {
119
- // If in last chunk, select the length byte.
120
- byte = self . select ( is_last_chunk, length_bytes[ j % 8 ] , byte) ;
139
+
140
+ if j >= SHA256_CHUNK_SIZE_BYTES - SHA256_INPUT_LENGTH_BYTE_SIZE {
141
+ // If in last chunk, this is a length byte.
142
+ byte = self . select (
143
+ is_last_chunk,
144
+ length_bytes[ j % SHA256_INPUT_LENGTH_BYTE_SIZE ] ,
145
+ byte,
146
+ ) ;
121
147
}
122
148
123
149
padded_bytes. push ( byte) ;
124
150
}
125
151
}
152
+ // These checks verify input_byte_length <= input.len().
153
+ self . is_equal ( message_byte_selector, false_t) ;
154
+ self . is_equal ( reached_last_chunk, true_t) ;
126
155
127
- // self.watch_slice(&padded_bytes, "padded bytes");
128
-
129
- assert_eq ! ( padded_bytes. len( ) , max_number_of_chunks * 64 ) ;
130
156
padded_bytes
131
157
}
132
158
}
0 commit comments