diff --git a/contracts/lib/LibKeccak.sol b/contracts/lib/LibKeccak.sol index ee25454..5d4508d 100644 --- a/contracts/lib/LibKeccak.sol +++ b/contracts/lib/LibKeccak.sol @@ -308,4 +308,54 @@ library LibKeccak { mstore(0x40, add(padded_, and(add(mload(padded_), 0x3F), not(0x1F)))) } } + + /// @notice Pads input data to an even multiple of the Keccak-f[1600] permutation block size, 1088 bits (136 bytes). + /// @dev Can clobber memory after `_data` if `_data` is not already a multiple of 136 bytes. + function padMemory(bytes memory _data) internal pure returns (bytes memory padded_) { + assembly { + padded_ := mload(0x40) + + // Grab the original length of `_data` + let len := mload(_data) + + let dataPtr := add(padded_, 0x20) + let endPtr := add(dataPtr, len) + + // Copy the data. + let originalDataPtr := add(_data, 0x20) + for { let i := 0 } lt(i, len) { i := add(i, 0x20) } { + mstore(add(dataPtr, i), mload(add(originalDataPtr, i))) + } + + let modBlockSize := mod(len, BLOCK_SIZE_BYTES) + switch modBlockSize + case false { + // If the input is a perfect multiple of the block size, then we add a full extra block of padding. + mstore8(endPtr, 0x01) + mstore8(sub(add(endPtr, BLOCK_SIZE_BYTES), 0x01), 0x80) + + // Update the length of the data to include the padding. + mstore(padded_, add(len, BLOCK_SIZE_BYTES)) + } + default { + // If the input is not a perfect multiple of the block size, then we add a partial block of padding. + // This should entail a set bit after the input, followed by as many zero bits as necessary to fill + // the block, followed by a single 1 bit in the lowest-order bit of the final block. + + let remaining := sub(BLOCK_SIZE_BYTES, modBlockSize) + let newLen := add(len, remaining) + + // Store the padding bits. + mstore8(add(dataPtr, sub(newLen, 0x01)), 0x80) + mstore8(endPtr, or(byte(0, mload(endPtr)), 0x01)) + + // Update the length of the data to include the padding. The length should be a multiple of the + // block size after this. + mstore(padded_, newLen) + } + + // Update the free memory pointer. + mstore(0x40, add(padded_, and(add(mload(padded_), 0x3F), not(0x1F)))) + } + } }