diff --git a/bench/speed.ml b/bench/speed.ml index 5ad317cc..56801fa1 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -26,17 +26,15 @@ let burn_period = 2.0 let sizes = [16; 64; 256; 1024; 8192] (* let sizes = [16] *) -let big_b = Bytes.create List.(hd (rev sizes)) - let burn f n = - let cs = Cstruct.of_string (Mirage_crypto_rng.generate n) in + let buf = Mirage_crypto_rng.generate n in let (t1, i1) = let rec loop it = - let t = Time.time ~n:it f cs in + let t = Time.time ~n:it f buf in if t > 0.2 then (t, it) else loop (it * 10) in loop 10 in let iters = int_of_float (float i1 *. burn_period /. t1) in - let time = Time.time ~n:iters f cs in + let time = Time.time ~n:iters f buf in (iters, time, float (n * iters) /. time) let mb = 1024. *. 1024. @@ -67,13 +65,7 @@ let count title f to_str args = Printf.printf " %s: %.03f ops per second (%d iters in %.03f)\n%!" (to_str arg) (float iters /. time) iters time -let msg = - let b = Cstruct.create 100 in - Cstruct.memset b 0xAA; - b - -let msg_str = - Cstruct.to_string msg +let msg_str = String.make 100 '\xAA' let msg_str_32 = String.sub msg_str 0 32 let msg_str_48 = String.sub msg_str 0 48 @@ -357,62 +349,63 @@ let benchmarks = [ fst ecdh_shares); bm "chacha20-poly1305" (fun name -> - let key = Mirage_crypto.Chacha20.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 32)) - and nonce = Cstruct.of_string (Mirage_crypto_rng.generate 8) in + let key = Mirage_crypto.Chacha20.of_secret (Mirage_crypto_rng.generate 32) + and nonce = Mirage_crypto_rng.generate 8 in throughput name (Mirage_crypto.Chacha20.authenticate_encrypt ~key ~nonce)) ; bm "aes-128-ecb" (fun name -> - let key = AES.ECB.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) in + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; bm "aes-128-cbc-e" (fun name -> - let key = AES.CBC.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) - and iv = Cstruct.of_string (Mirage_crypto_rng.generate 16) in + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in throughput name (fun cs -> AES.CBC.encrypt ~key ~iv cs)) ; bm "aes-128-cbc-d" (fun name -> - let key = AES.CBC.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) - and iv = Cstruct.of_string (Mirage_crypto_rng.generate 16) in + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in throughput name (fun cs -> AES.CBC.decrypt ~key ~iv cs)) ; bm "aes-128-ctr" (fun name -> - let key = Mirage_crypto_rng.generate 16 |> Cstruct.of_string |> AES.CTR.of_secret - and ctr = Mirage_crypto_rng.generate 16 |> Cstruct.of_string |> AES.CTR.ctr_of_cstruct in + let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret + and ctr = Mirage_crypto_rng.generate 16 |> AES.CTR.ctr_of_octets in throughput name (fun cs -> AES.CTR.encrypt ~key ~ctr cs)) ; bm "aes-128-gcm" (fun name -> - let key = AES.GCM.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) - and nonce = Cstruct.of_string (Mirage_crypto_rng.generate 12) in + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce cs)); bm "aes-128-ghash" (fun name -> - let key = AES.GCM.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) - and nonce = Cstruct.of_string (Mirage_crypto_rng.generate 12) in - throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce ~adata:cs Cstruct.empty)); + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in + throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce ~adata:cs "")); bm "aes-128-ccm" (fun name -> - let key = AES.CCM16.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 16)) - and nonce = Cstruct.of_string (Mirage_crypto_rng.generate 10) in + let key = AES.CCM16.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 10 in throughput name (fun cs -> AES.CCM16.authenticate_encrypt ~key ~nonce cs)); bm "aes-192-ecb" (fun name -> - let key = AES.ECB.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 24)) in + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; bm "aes-256-ecb" (fun name -> - let key = AES.ECB.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 32)) in + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; bm "d3des-ecb" (fun name -> - let key = DES.ECB.of_secret (Cstruct.of_string (Mirage_crypto_rng.generate 24)) in + let key = DES.ECB.of_secret (Mirage_crypto_rng.generate 24) in throughput name (fun cs -> DES.ECB.encrypt ~key cs)) ; bm "fortuna" (fun name -> let open Mirage_crypto_rng.Fortuna in let g = create () in reseed ~g "abcd" ; - throughput name (fun cs -> - generate_into ~g big_b ~off:0 (Cstruct.length cs))) ; + throughput name (fun buf -> + let buf = Bytes.unsafe_of_string buf in + generate_into ~g buf ~off:0 (Bytes.length buf))) ; ] let help () = diff --git a/mirage-crypto-ec.opam b/mirage-crypto-ec.opam index a248dafe..5acd4e3d 100644 --- a/mirage-crypto-ec.opam +++ b/mirage-crypto-ec.opam @@ -31,7 +31,6 @@ depends: [ "eqaf" {>= "0.7"} "mirage-crypto-rng" {=version} "digestif" {>= "1.2.0"} - "hex" {with-test} "alcotest" {with-test & >= "0.8.1"} "ppx_deriving_yojson" {with-test} "ppx_deriving" {with-test} diff --git a/mirage-crypto-rng-async.opam b/mirage-crypto-rng-async.opam index 890c6034..ac01c562 100644 --- a/mirage-crypto-rng-async.opam +++ b/mirage-crypto-rng-async.opam @@ -19,6 +19,7 @@ depends: [ "async" {>= "v0.14"} "logs" "mirage-crypto-rng" {=version} + "ohex" {with-test & >= "0.2.0"} ] available: os != "win32" description: """ diff --git a/mirage-crypto-rng-eio.opam b/mirage-crypto-rng-eio.opam index c1fb8b33..6348bb0e 100644 --- a/mirage-crypto-rng-eio.opam +++ b/mirage-crypto-rng-eio.opam @@ -21,6 +21,7 @@ depends: [ "duration" "mtime" "eio_main" {with-test} + "ohex" {with-test & >= "0.2.0"} ] description: """ Mirage-crypto-rng-eio feeds the entropy source for Mirage_crypto_rng-based diff --git a/mirage-crypto-rng-mirage.opam b/mirage-crypto-rng-mirage.opam index b22480ef..5f99c6bf 100644 --- a/mirage-crypto-rng-mirage.opam +++ b/mirage-crypto-rng-mirage.opam @@ -25,6 +25,7 @@ depends: [ "mirage-unix" {with-test & >= "5.0.0"} "mirage-time-unix" {with-test & >= "2.0.0"} "mirage-clock-unix" {with-test & >= "3.0.0"} + "ohex" {with-test & >= "0.2.0"} ] description: """ Mirage-crypto-rng-mirage provides entropy collection code for the RNG. diff --git a/mirage-crypto-rng.opam b/mirage-crypto-rng.opam index 3b54dde8..701c88cc 100644 --- a/mirage-crypto-rng.opam +++ b/mirage-crypto-rng.opam @@ -22,6 +22,7 @@ depends: [ "digestif" {>= "1.1.4"} "ounit2" {with-test} "randomconv" {with-test & >= "0.2.0"} + "ohex" {with-test & >= "0.2.0"} ] conflicts: [ "mirage-runtime" {< "3.8.0"} ] description: """ diff --git a/mirage-crypto.opam b/mirage-crypto.opam index 1f246383..484cda49 100644 --- a/mirage-crypto.opam +++ b/mirage-crypto.opam @@ -17,7 +17,7 @@ depends: [ "dune" {>= "2.7"} "dune-configurator" {>= "2.0.0"} "ounit2" {with-test} - "cstruct" {>="6.0.0"} + "ohex" {with-test & >= "0.2.0"} "eqaf" {>= "0.8"} ] conflicts: [ diff --git a/rng/fortuna.ml b/rng/fortuna.ml index 6cd696df..1930682d 100644 --- a/rng/fortuna.ml +++ b/rng/fortuna.ml @@ -39,7 +39,7 @@ let create ?time () = let k = String.make 32 '\x00' in { ctr = (0L, 0L) ; secret = k - ; key = AES_CTR.of_secret (Cstruct.of_string k) + ; key = AES_CTR.of_secret k ; pools = Array.make pools SHAd256.empty ; pool0_size = 0 ; reseed_count = 0 @@ -54,7 +54,7 @@ let seeded ~g = (* XXX We might want to erase the old key. *) let set_key ~g sec = g.secret <- sec ; - g.key <- AES_CTR.of_secret (Cstruct.of_string sec) + g.key <- AES_CTR.of_secret sec let reseedi ~g iter = set_key ~g @@ SHAd256.digesti (fun f -> f g.secret; iter f); @@ -67,7 +67,7 @@ let reseed ~g cs = reseedi ~g (iter1 cs) let generate_rekey ~g buf ~off len = let b = len // block + 2 in let n = b * block in - let r = Cstruct.to_string (AES_CTR.stream ~key:g.key ~ctr:g.ctr n) in + let r = AES_CTR.stream ~key:g.key ~ctr:g.ctr n in Bytes.blit_string r 0 buf off len; let r2 = String.sub r (n - 32) 32 in set_key ~g r2 ; diff --git a/src/aead.ml b/src/aead.ml index c75fd10d..a03214e1 100644 --- a/src/aead.ml +++ b/src/aead.ml @@ -1,13 +1,13 @@ module type AEAD = sig val tag_size : int type key - val of_secret : Cstruct.t -> key - val authenticate_encrypt : key:key -> nonce:Cstruct.t -> ?adata:Cstruct.t -> - Cstruct.t -> Cstruct.t - val authenticate_decrypt : key:key -> nonce:Cstruct.t -> ?adata:Cstruct.t -> - Cstruct.t -> Cstruct.t option - val authenticate_encrypt_tag : key:key -> nonce:Cstruct.t -> - ?adata:Cstruct.t -> Cstruct.t -> Cstruct.t * Cstruct.t - val authenticate_decrypt_tag : key:key -> nonce:Cstruct.t -> ?adata:Cstruct.t -> - tag:Cstruct.t -> Cstruct.t -> Cstruct.t option + val of_secret : string -> key + val authenticate_encrypt : key:key -> nonce:string -> ?adata:string -> + string -> string + val authenticate_decrypt : key:key -> nonce:string -> ?adata:string -> + string -> string option + val authenticate_encrypt_tag : key:key -> nonce:string -> ?adata:string -> + string -> string * string + val authenticate_decrypt_tag : key:key -> nonce:string -> ?adata:string -> + tag:string -> string -> string option end diff --git a/src/ccm.ml b/src/ccm.ml index 7e7c4af9..746c02d5 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -1,161 +1,157 @@ open Uncommon -let (<+>) = Cs.(<+>) - let block_size = 16 let flags bit6 len1 len2 = - let byte = Cstruct.create 1 - and data = bit6 lsl 6 + len1 lsl 3 + len2 in - Cstruct.set_uint8 byte 0 data ; - byte + bit6 lsl 6 + len1 lsl 3 + len2 -let encode_len_buf size value buf = +let encode_len buf ~off size value = let rec ass num = function - | 0 -> Cstruct.set_uint8 buf 0 num - | m -> Cstruct.set_uint8 buf m (num land 0xff) ; ass (num lsr 8) (pred m) + | 0 -> Bytes.set_uint8 buf off num + | m -> + Bytes.set_uint8 buf (off + m) (num land 0xff); + ass (num lsr 8) (pred m) in ass value (pred size) -let encode_len size value = - let b = Cstruct.create size in - encode_len_buf size value b ; - b - -let format nonce adata q t (* mac len *) = - (* assume n <- [7..13] *) - (* assume t is valid mac size *) - (* n + q = 15 *) - (* a < 2 ^ 64 *) - let n = Cstruct.length nonce in +let set_format buf ?(off = 0) nonce flag_val value = + let n = String.length nonce in let small_q = 15 - n in - (* first byte (flags): *) - (* reserved | adata | (t - 2) / 2 | q - 1 *) - let b6 = if Cstruct.length adata = 0 then 0 else 1 in - let flag = flags b6 ((t - 2) / 2) (small_q - 1) in (* first octet block: 0 : flags 1..15 - q : N 16 - q..15 : Q *) - let qblock = encode_len small_q q in - flag <+> nonce <+> qblock - -let pad_block b = - let size = Cstruct.length b in - Cs.rpad b (size // block_size * block_size) 0 + Bytes.set_uint8 buf off flag_val; + Bytes.unsafe_blit_string nonce 0 buf (off + 1) n; + encode_len buf ~off:(off + n + 1) small_q value let gen_adata a = - let lbuf = - match Cstruct.length a with + let llen, set_llen = + match String.length a with | x when x < (1 lsl 16 - 1 lsl 8) -> - let buf = Cstruct.create 2 in - Cstruct.BE.set_uint16 buf 0 x ; - buf + 2, (fun buf off -> Bytes.set_uint16_be buf off x) | x when Sys.int_size < 32 || x < (1 lsl 32) -> - let buf = Cstruct.create 4 in - Cstruct.BE.set_uint32 buf 0 (Int32.of_int x) ; - Cs.of_bytes [0xff ; 0xfe] <+> buf + 6, (fun buf off -> + Bytes.set_uint16_be buf off 0xfffe; + Bytes.set_int32_be buf (off + 2) (Int32.of_int x)) | x -> - let buf = Cstruct.create 8 in - Cstruct.BE.set_uint64 buf 0 (Int64.of_int x) ; - Cs.of_bytes [0xff ; 0xff] <+> buf + 10, (fun buf off -> + Bytes.set_uint16_be buf off 0xffff; + Bytes.set_int64_be buf (off + 2) (Int64.of_int x)) in - pad_block (lbuf <+> a) - -let gen_ctr_prefix nonce = - let n = Cstruct.length nonce in - let small_q = 15 - n in - let flag = flags 0 0 (small_q - 1) in - (flag <+> nonce, succ n, small_q) + let to_pad = + let leftover = (llen + String.length a) mod block_size in + block_size - leftover + in + llen + String.length a + to_pad, + fun buf off -> + set_llen buf off; + Bytes.blit_string a 0 buf (off + llen) (String.length a); + Bytes.fill buf (off + llen + String.length a) to_pad '\000' let gen_ctr nonce i = - let pre, _, q = gen_ctr_prefix nonce in - pre <+> encode_len q i + let n = String.length nonce in + let small_q = 15 - n in + let flag_val = flags 0 0 (small_q - 1) in + let buf = Bytes.create 16 in + set_format buf nonce flag_val i; + buf let prepare_header nonce adata plen tlen = - let ada = if Cstruct.length adata = 0 then Cstruct.empty else gen_adata adata in - format nonce adata plen tlen <+> ada + let small_q = 15 - String.length nonce in + let b6 = if String.length adata = 0 then 0 else 1 in + let flag_val = flags b6 ((tlen - 2) / 2) (small_q - 1) in + if String.length adata = 0 then + let hdr = Bytes.create 16 in + set_format hdr nonce flag_val plen; + hdr + else + let len, set = gen_adata adata in + let buf = Bytes.create (16 + len) in + set_format buf nonce flag_val plen; + set buf 16; + buf type mode = Encrypt | Decrypt let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = - let datalen = Cstruct.length data in + let datalen = String.length data in let cbcheader = prepare_header nonce adata datalen maclen in - let target = Cstruct.create datalen in + let dst = Bytes.create datalen in - let blkprefix, blkpreflen, preflen = gen_ctr_prefix nonce in + let small_q = 15 - String.length nonce in + let ctr_flag_val = flags 0 0 (small_q - 1) in let ctrblock i block = - Cstruct.blit blkprefix 0 block 0 blkpreflen ; - encode_len_buf preflen i (Cstruct.shift block blkpreflen) ; - cipher ~key block block + Bytes.set_uint8 block 0 ctr_flag_val; + Bytes.unsafe_blit_string nonce 0 block 1 (String.length nonce); + encode_len block ~off:(String.length nonce + 1) small_q i; + cipher ~key (Bytes.unsafe_to_string block) ~src_off:0 block ~dst_off:0 in - let cbc iv block = - Cs.xor_into iv block block_size ; - cipher ~key block block + let cbc iv src_off block dst_off = + xor_into iv ~src_off block ~dst_off block_size ; + cipher ~key (Bytes.unsafe_to_string block) ~src_off:dst_off block ~dst_off in let cbcprep = - let rec doit iv block = - match Cstruct.length block with - | 0 -> iv + let rec doit iv iv_off block block_off = + match Bytes.length block - block_off with + | 0 -> Bytes.sub iv iv_off block_size | _ -> - cbc iv block ; - doit (Cstruct.sub block 0 block_size) - (Cstruct.shift block block_size) + cbc (Bytes.unsafe_to_string iv) iv_off block block_off; + doit block block_off block (block_off + block_size) in - doit (Cstruct.create block_size) cbcheader + doit (Bytes.make block_size '\x00') 0 cbcheader 0 in - let rec loop iv ctr src target = - let cbcblock = + let rec loop iv ctr src src_off dst dst_off= + let cbcblock, cbc_off = match mode with - | Encrypt -> src - | Decrypt -> target + | Encrypt -> src, src_off + | Decrypt -> Bytes.unsafe_to_string dst, dst_off in - match Cstruct.length src with + match String.length src - src_off with | 0 -> iv | x when x < block_size -> - let ctrbl = pad_block target in - ctrblock ctr ctrbl ; - Cstruct.blit ctrbl 0 target 0 x ; - Cs.xor_into src target x ; - let cbblock = pad_block cbcblock in - cbc cbblock iv ; - iv + let buf = Bytes.make block_size '\x00' in + Bytes.unsafe_blit dst dst_off buf 0 x; + ctrblock ctr buf ; + Bytes.unsafe_blit buf 0 dst dst_off x ; + xor_into src ~src_off dst ~dst_off x ; + Bytes.unsafe_blit_string cbcblock cbc_off buf 0 x; + Bytes.unsafe_fill buf x (block_size - x) '\x00'; + cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 ; + iv | _ -> - ctrblock ctr target ; - Cs.xor_into src target block_size ; - cbc cbcblock iv ; - loop iv - (succ ctr) - (Cstruct.shift src block_size) - (Cstruct.shift target block_size) + ctrblock ctr dst ; + xor_into src ~src_off dst ~dst_off block_size ; + cbc cbcblock cbc_off iv 0 ; + loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) in - let last = loop cbcprep 1 data target in - let t = Cstruct.sub last 0 maclen in - (target, t) + let last = loop cbcprep 1 data 0 dst 0 in + let t = Bytes.sub last 0 maclen in + (dst, t) let crypto_t t nonce cipher key = let ctr = gen_ctr nonce 0 in - cipher ~key ctr ctr ; - Cs.xor_into ctr t (Cstruct.length t) + cipher ~key (Bytes.unsafe_to_string ctr) ~src_off:0 ctr ~dst_off:0 ; + xor_into (Bytes.unsafe_to_string ctr) t (Bytes.length t) let valid_nonce nonce = - let nsize = Cstruct.length nonce in + let nsize = String.length nonce in if nsize < 7 || nsize > 13 then - invalid_arg "CCM: nonce length not between 7 and 13: %d" nsize + invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize let generation_encryption ~cipher ~key ~nonce ~maclen ~adata data = valid_nonce nonce; let cdata, t = crypto_core ~cipher ~mode:Encrypt ~key ~nonce ~maclen ~adata data in crypto_t t nonce cipher key ; - cdata, t + Bytes.unsafe_to_string cdata, Bytes.unsafe_to_string t let decryption_verification ~cipher ~key ~nonce ~maclen ~adata ~tag data = valid_nonce nonce; let cdata, t = crypto_core ~cipher ~mode:Decrypt ~key ~nonce ~maclen ~adata data in - crypto_t tag nonce cipher key ; - match Eqaf_cstruct.equal tag t with - | true -> Some cdata + crypto_t t nonce cipher key ; + match Eqaf.equal tag (Bytes.unsafe_to_string t) with + | true -> Some (Bytes.unsafe_to_string cdata) | false -> None diff --git a/src/chacha20.ml b/src/chacha20.ml index 0e58b8cc..f0d97840 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -6,7 +6,7 @@ let block = 64 type key = string -let of_secret a = Cstruct.to_string a +let of_secret a = a let chacha20_block state idx key_stream = Native.Chacha.round 10 state key_stream idx @@ -89,33 +89,30 @@ let mac ~key ~adata ciphertext = in P.macl ~key [ adata ; pad16 adata ; ciphertext ; pad16 ciphertext ; len ] -let authenticate_encrypt_tag ~key ~nonce ?(adata = Cstruct.empty) data = - let adata = Cstruct.to_string adata in - let nonce = Cstruct.to_string nonce in - let data = Cstruct.to_string data in +let authenticate_encrypt_tag ~key ~nonce ?(adata = "") data = let poly1305_key = generate_poly1305_key ~key ~nonce in let ciphertext = crypt ~key ~nonce ~ctr:1L data in let mac = mac ~key:poly1305_key ~adata ciphertext in - Cstruct.of_string ciphertext, Cstruct.of_string mac + ciphertext, mac let authenticate_encrypt ~key ~nonce ?adata data = let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - Cstruct.append cdata ctag + cdata ^ ctag -let authenticate_decrypt_tag ~key ~nonce ?(adata = Cstruct.empty) ~tag data = - let adata = Cstruct.to_string adata in - let nonce = Cstruct.to_string nonce in - let data = Cstruct.to_string data in +let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag data = let poly1305_key = generate_poly1305_key ~key ~nonce in let ctag = mac ~key:poly1305_key ~adata data in let plain = crypt ~key ~nonce ~ctr:1L data in - if Eqaf_cstruct.equal tag (Cstruct.of_string ctag) then Some (Cstruct.of_string plain) else None + if Eqaf.equal tag ctag then Some plain else None let authenticate_decrypt ~key ~nonce ?adata data = - if Cstruct.length data < P.mac_size then + if String.length data < P.mac_size then None else - let cipher, tag = Cstruct.split data (Cstruct.length data - P.mac_size) in + let cipher, tag = + let p = String.length data - P.mac_size in + String.sub data 0 p, String.sub data p P.mac_size + in authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher let tag_size = P.mac_size diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 889826e0..7fd43365 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -7,59 +7,59 @@ module S = struct type ekey type dkey - val of_secret : Cstruct.t -> ekey * dkey - val e_of_secret : Cstruct.t -> ekey - val d_of_secret : Cstruct.t -> dkey + val of_secret : string -> ekey * dkey + val e_of_secret : string -> ekey + val d_of_secret : string -> dkey val key : int array val block : int (* XXX currently unsafe point *) - val encrypt : key:ekey -> blocks:int -> Native.buffer -> int -> Native.buffer -> int -> unit - val decrypt : key:dkey -> blocks:int -> Native.buffer -> int -> Native.buffer -> int -> unit + val encrypt : key:ekey -> blocks:int -> string -> int -> bytes -> int -> unit + val decrypt : key:dkey -> blocks:int -> string -> int -> bytes -> int -> unit end module type ECB = sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key val key_sizes : int array val block_size : int - val encrypt : key:key -> Cstruct.t -> Cstruct.t - val decrypt : key:key -> Cstruct.t -> Cstruct.t + val encrypt : key:key -> string -> string + val decrypt : key:key -> string -> string end module type CBC = sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key val key_sizes : int array val block_size : int - val encrypt : key:key -> iv:Cstruct.t -> Cstruct.t -> Cstruct.t - val decrypt : key:key -> iv:Cstruct.t -> Cstruct.t -> Cstruct.t - val next_iv : iv:Cstruct.t -> Cstruct.t -> Cstruct.t + val encrypt : key:key -> iv:string -> string -> string + val decrypt : key:key -> iv:string -> string -> string + val next_iv : iv:string -> string -> string end module type CTR = sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key type ctr val key_sizes : int array val block_size : int - val stream : key:key -> ctr:ctr -> int -> Cstruct.t - val encrypt : key:key -> ctr:ctr -> Cstruct.t -> Cstruct.t - val decrypt : key:key -> ctr:ctr -> Cstruct.t -> Cstruct.t + val stream : key:key -> ctr:ctr -> int -> string + val encrypt : key:key -> ctr:ctr -> string -> string + val decrypt : key:key -> ctr:ctr -> string -> string val add_ctr : ctr -> int64 -> ctr - val next_ctr : ctr:ctr -> Cstruct.t -> ctr - val ctr_of_cstruct : Cstruct.t -> ctr + val next_ctr : ctr:ctr -> string -> ctr + val ctr_of_octets : string -> ctr end module type GCM = sig @@ -78,15 +78,12 @@ module S = struct end module Counters = struct - - open Cstruct - module type S = sig type ctr val size : int val add : ctr -> int64 -> ctr - val of_cstruct : Cstruct.t -> ctr - val unsafe_count_into : ctr -> Native.buffer -> int -> blocks:int -> unit + val of_octets : string -> ctr + val unsafe_count_into : ctr -> bytes -> int -> blocks:int -> unit end let _tmp = Bytes.make 16 '\x00' @@ -94,7 +91,8 @@ module Counters = struct module C64be = struct type ctr = int64 let size = 8 - let of_cstruct cs = BE.get_uint64 cs 0 + (* Until OCaml 4.13 is lower bound*) + let of_octets cs = Bytes.get_int64_be (Bytes.unsafe_of_string cs) 0 let add = Int64.add let unsafe_count_into t buf off ~blocks = Bytes.set_int64_be _tmp 0 t; @@ -104,7 +102,9 @@ module Counters = struct module C128be = struct type ctr = int64 * int64 let size = 16 - let of_cstruct cs = BE.(get_uint64 cs 0, get_uint64 cs 8) + let of_octets cs = + let buf = Bytes.unsafe_of_string cs in + Bytes.(get_int64_be buf 0, get_int64_be buf 8) let add (w1, w0) n = let w0' = Int64.add w0 n in let flip = if Int64.logxor w0 w0' < 0L then w0' > w0 else w0' < w0 in @@ -126,9 +126,6 @@ module Counters = struct end module Modes = struct - - open Cstruct - module ECB_of (Core : S.Core) : S.ECB = struct type key = Core.ekey * Core.dkey @@ -139,11 +136,11 @@ module Modes = struct let (encrypt, decrypt) = let ecb xform key src = - let n = src.len in - if n mod block_size <> 0 then invalid_arg "ECB: length %d" n; - let dst = create n in - xform ~key ~blocks:(n / block_size) src.buffer src.off dst.buffer dst.off ; - dst + let n = String.length src in + if n mod block_size <> 0 then invalid_arg "ECB: length %u" n; + let dst = Bytes.create n in + xform ~key ~blocks:(n / block_size) src 0 dst 0 ; + Bytes.unsafe_to_string dst in (fun ~key:(key, _) src -> ecb Core.encrypt key src), (fun ~key:(_, key) src -> ecb Core.decrypt key src) @@ -160,38 +157,38 @@ module Modes = struct let of_secret = Core.of_secret let bounds_check ~iv cs = - if iv.len <> block then invalid_arg "CBC: IV length %d" iv.len; - if cs.len mod block <> 0 then - invalid_arg "CBC: argument length %d" cs.len + if String.length iv <> block then invalid_arg "CBC: IV length %u" (String.length iv); + if String.length cs mod block <> 0 then + invalid_arg "CBC: argument length %u" (String.length cs) let next_iv ~iv cs = bounds_check ~iv cs ; - if cs.len > 0 then - sub cs (cs.len - block_size) block_size + if String.length cs > 0 then + String.sub cs (String.length cs - block_size) block_size else iv let encrypt ~key:(key, _) ~iv src = bounds_check ~iv src ; - let msg = Cs.clone src in - let dst = msg.buffer in + let dst = Bytes.of_string src in let rec loop iv iv_i dst_i = function 0 -> () - | b -> Native.xor_into iv iv_i dst dst_i block ; - Core.encrypt ~key ~blocks:1 dst dst_i dst dst_i ; - loop dst dst_i (dst_i + block) (b - 1) + | b -> Native.xor_into_bytes iv iv_i dst dst_i block ; + Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; + loop (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) in - loop iv.buffer iv.off msg.off (msg.len / block) ; msg + loop iv 0 0 (Bytes.length dst / block) ; + Bytes.unsafe_to_string dst let decrypt ~key:(_, key) ~iv src = bounds_check ~iv src ; - let msg = create src.len - and b = src.len / block in + let msg = Bytes.create (String.length src) + and b = String.length src / block in if b > 0 then begin - Core.decrypt ~key ~blocks:b src.buffer src.off msg.buffer msg.off ; - Native.xor_into iv.buffer iv.off msg.buffer msg.off block ; - Native.xor_into src.buffer src.off msg.buffer (msg.off + block) ((b - 1) * block) ; + Core.decrypt ~key ~blocks:b src 0 msg 0 ; + Native.xor_into_bytes iv 0 msg 0 block ; + Native.xor_into_bytes src 0 msg block ((b - 1) * block) ; end ; - msg + Bytes.unsafe_to_string msg end @@ -208,43 +205,51 @@ module Modes = struct let of_secret = Core.e_of_secret let stream ~key ~ctr n = - let blocks = imax 0 n // block_size in - let buf = Native.buffer (blocks * block_size) in + let blocks = imax 0 n / block_size in + let buf = Bytes.create n in Ctr.unsafe_count_into ctr ~blocks buf 0 ; - Core.encrypt ~key ~blocks buf 0 buf 0 ; - of_bigarray ~len:n buf + Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) 0 buf 0 ; + let slack = imax 0 n mod block_size in + if slack <> 0 then begin + let buf' = Bytes.create block_size in + let ctr = Ctr.add ctr (Int64.of_int blocks) in + Ctr.unsafe_count_into ctr ~blocks:1 buf' 0 ; + Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string buf') 0 buf' 0 ; + Bytes.blit buf' 0 buf (blocks * block_size) slack + end; + Bytes.unsafe_to_string buf let encrypt ~key ~ctr src = - let res = stream ~key ~ctr src.len in - Native.xor_into src.buffer src.off res.buffer res.off src.len ; - res + let res = Bytes.unsafe_of_string (stream ~key ~ctr (String.length src)) in + Native.xor_into_bytes src 0 res 0 (String.length src) ; + Bytes.unsafe_to_string res let decrypt = encrypt let add_ctr = Ctr.add - let next_ctr ~ctr msg = add_ctr ctr (Int64.of_int @@ msg.len // block_size) - let ctr_of_cstruct = Ctr.of_cstruct + let next_ctr ~ctr msg = add_ctr ctr (Int64.of_int @@ String.length msg // block_size) + let ctr_of_octets = Ctr.of_octets end module GHASH : sig type key - val derive : Cstruct.t -> key - val digesti : key:key -> (Cstruct.t Uncommon.iter) -> Cstruct.t + val derive : string -> key + val digesti : key:key -> (string Uncommon.iter) -> string val tagsize : int end = struct - type key = bytes + type key = string let keysize = Native.GHASH.keysize () let tagsize = 16 let derive cs = - assert (cs.len >= tagsize); + assert (String.length cs >= tagsize); let k = Bytes.create keysize in - Native.GHASH.keyinit cs.buffer cs.off k; k - let _cs = create_unsafe tagsize + Native.GHASH.keyinit cs k; + Bytes.unsafe_to_string k let hash0 = Bytes.make tagsize '\x00' - let digesti ~key i = (* Clobbers `_cs`! *) + let digesti ~key i = let res = Bytes.copy hash0 in - i (fun cs -> Native.GHASH.ghash key res cs.buffer cs.off cs.len); - blit_from_bytes res 0 _cs 0 tagsize; _cs + i (fun cs -> Native.GHASH.ghash key res cs (String.length cs)); + Bytes.unsafe_to_string res end module GCM_of (C : S.Core) : S.GCM = struct @@ -256,25 +261,38 @@ module Modes = struct let tag_size = GHASH.tagsize let key_sizes, block_size = C.(key, block) - let z128, h = create block_size, create block_size + let z128, h = String.make block_size '\x00', Bytes.create block_size let of_secret cs = let key = C.e_of_secret cs in - C.encrypt ~key ~blocks:1 z128.buffer z128.off h.buffer h.off; - { key ; hkey = GHASH.derive h } + C.encrypt ~key ~blocks:1 z128 0 h 0; + { key ; hkey = GHASH.derive (Bytes.unsafe_to_string h) } - let bits64 cs = Int64.of_int (length cs * 8) - let pack64s = let _cs = create_unsafe 16 in fun a b -> - BE.set_uint64 _cs 0 a; BE.set_uint64 _cs 8 b; _cs + let bits64 cs = Int64.of_int (String.length cs * 8) - let counter ~hkey nonce = match length nonce with - | 0 -> invalid_arg "GCM: invalid nonce of length 0" - | 12 -> let (w1, w2) = BE.get_uint64 nonce 0, BE.get_uint32 nonce 8 in - (w1, Int64.(shift_left (of_int32 w2) 32 |> add 1L)) - | _ -> CTR.ctr_of_cstruct @@ - GHASH.digesti ~key:hkey @@ iter2 nonce (pack64s 0L (bits64 nonce)) + let pack64s = + let _cs = Bytes.create 16 in + fun a b -> + Bytes.set_int64_be _cs 0 a; + Bytes.set_int64_be _cs 8 b; + Bytes.unsafe_to_string _cs - let tag ~key ~hkey ~ctr ?(adata=Cstruct.empty) cdata = + (* OCaml 4.13 *) + let string_get_int64 s idx = + Bytes.get_int64_be (Bytes.unsafe_of_string s) idx + let string_get_int32 s idx = + Bytes.get_int32_be (Bytes.unsafe_of_string s) idx + + let counter ~hkey nonce = match String.length nonce with + | 0 -> invalid_arg "GCM: invalid nonce of length 0" + | 12 -> + let (w1, w2) = string_get_int64 nonce 0, string_get_int32 nonce 8 in + (w1, Int64.(shift_left (of_int32 w2) 32 |> add 1L)) + | _ -> + CTR.ctr_of_octets @@ + GHASH.digesti ~key:hkey @@ iter2 nonce (pack64s 0L (bits64 nonce)) + + let tag ~key ~hkey ~ctr ?(adata = "") cdata = CTR.encrypt ~key ~ctr @@ GHASH.digesti ~key:hkey @@ iter3 adata cdata (pack64s (bits64 adata) (bits64 cdata)) @@ -287,20 +305,21 @@ module Modes = struct let authenticate_encrypt ~key ~nonce ?adata data = let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - Cstruct.append cdata ctag + cdata ^ ctag let authenticate_decrypt_tag ~key:{ key; hkey } ~nonce ?adata ~tag:tag_data cipher = let ctr = counter ~hkey nonce in let data = CTR.(encrypt ~key ~ctr:(add_ctr ctr 1L) cipher) in let ctag = tag ~key ~hkey ~ctr ?adata cipher in - if Eqaf_cstruct.equal tag_data ctag then Some data else None + if Eqaf.equal tag_data ctag then Some data else None let authenticate_decrypt ~key ~nonce ?adata cdata = - if Cstruct.length cdata < tag_size then + if String.length cdata < tag_size then None else let cipher, tag = - Cstruct.split cdata (Cstruct.length cdata - tag_size) + String.sub cdata 0 (String.length cdata - tag_size), + String.sub cdata (String.length cdata - tag_size) tag_size in authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher end @@ -317,26 +336,29 @@ module Modes = struct let (key_sizes, block_size) = C.(key, block) - let cipher ~key src dst = - if src.len < block_size || dst.len < block_size then - invalid_arg "src len %d, dst len %d" src.len dst.len; - C.encrypt ~key ~blocks:1 src.buffer src.off dst.buffer dst.off + let cipher ~key src ~src_off dst ~dst_off = + if String.length src - src_off < block_size || Bytes.length dst - dst_off < block_size then + invalid_arg "src len %u, dst len %u" (String.length src - src_off) (Bytes.length dst - dst_off); + C.encrypt ~key ~blocks:1 src src_off dst dst_off - let authenticate_encrypt_tag ~key ~nonce ?(adata = Cstruct.empty) cs = + let authenticate_encrypt_tag ~key ~nonce ?(adata = "") cs = Ccm.generation_encryption ~cipher ~key ~nonce ~maclen:tag_size ~adata cs let authenticate_encrypt ~key ~nonce ?adata cs = let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata cs in - Cstruct.append cdata ctag + cdata ^ ctag - let authenticate_decrypt_tag ~key ~nonce ?(adata = Cstruct.empty) ~tag cs = + let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag cs = Ccm.decryption_verification ~cipher ~key ~nonce ~maclen:tag_size ~adata ~tag cs let authenticate_decrypt ~key ~nonce ?adata data = - if Cstruct.length data < tag_size then + if String.length data < tag_size then None else - let data, tag = Cstruct.split data (Cstruct.length data - tag_size) in + let data, tag = + String.sub data 0 (String.length data - tag_size), + String.sub data (String.length data - tag_size) tag_size + in authenticate_decrypt_tag ~key ~nonce ?adata ~tag data end end @@ -348,19 +370,20 @@ module AES = struct let key = [| 16; 24; 32 |] let block = 16 - type ekey = Native.buffer * int - type dkey = Native.buffer * int + type ekey = string * int + type dkey = string * int - let of_secret_with init { Cstruct.buffer ; off ; len } = + let of_secret_with init key = let rounds = - match len with - | 16|24|32 -> len / 4 + 6 - | _ -> invalid_arg "AES.of_secret: key length %d" len in - let rk = Native.(buffer @@ AES.rk_s rounds) in - init buffer off rk rounds ; - (rk, rounds) + match String.length key with + | 16 | 24 | 32 -> String.length key / 4 + 6 + | _ -> invalid_arg "AES.of_secret: key length %u" (String.length key) + in + let rk = Bytes.create (Native.AES.rk_s rounds) in + init key rk rounds ; + Bytes.unsafe_to_string rk, rounds - let derive_d ?e buf off rk rs = Native.AES.derive_d buf off rk rs e + let derive_d ?e buf rk rs = Native.AES.derive_d buf rk rs e let e_of_secret = of_secret_with Native.AES.derive_e let d_of_secret = of_secret_with (derive_d ?e:None) @@ -395,18 +418,19 @@ module DES = struct let key = [| 24 |] let block = 8 - type ekey = Native.buffer - type dkey = Native.buffer + type ekey = string + type dkey = string let k_s = Native.DES.k_s () - let gen_of_secret ~direction { Cstruct.buffer ; off ; len } = - if len <> 24 then - invalid_arg "DES.of_secret: key length %d" len ; - let key = Native.buffer k_s in - Native.DES.des3key buffer off direction ; - Native.DES.cp3key key ; - key + let gen_of_secret ~direction key = + if String.length key <> 24 then + invalid_arg "DES.of_secret: key length %u" (String.length key) ; + let key = Bytes.of_string key in + let keybuf = Bytes.create k_s in + Native.DES.des3key key direction ; + Native.DES.cp3key keybuf ; + Bytes.unsafe_to_string keybuf let e_of_secret = gen_of_secret ~direction:0 let d_of_secret = gen_of_secret ~direction:1 diff --git a/src/cipher_stream.ml b/src/cipher_stream.ml index d8714a68..87aaec0c 100644 --- a/src/cipher_stream.ml +++ b/src/cipher_stream.ml @@ -2,26 +2,26 @@ open Uncommon module type S = sig type key - type result = { message : Cstruct.t ; key : key } - val of_secret : Cstruct.t -> key - val encrypt : key:key -> Cstruct.t -> result - val decrypt : key:key -> Cstruct.t -> result + type result = { message : string ; key : key } + val of_secret : string -> key + val encrypt : key:key -> string -> result + val decrypt : key:key -> string -> result end module ARC4 = struct type key = int * int * int array - type result = { message : Cstruct.t ; key : key } + type result = { message : string ; key : key } - let of_secret cs = - let len = Cstruct.length cs in + let of_secret buf = + let len = String.length buf in if len < 1 || len > 256 then invalid_arg "ARC4.of_secret: key size %d" len; let s = Array.init 256 (fun x -> x) in let rec loop j = function | 256 -> () | i -> - let x = Cstruct.get_uint8 cs (i mod len) in + let x = string_get_uint8 buf (i mod len) in let si = s.(i) in let j = (j + si + x) land 0xff in let sj = s.(j) in @@ -30,10 +30,10 @@ module ARC4 = struct in ( loop 0 0 ; (0, 0, s) ) - let encrypt ~key:(i, j, s') cs = + let encrypt ~key:(i, j, s') buf = let s = Array.copy s' - and len = Cstruct.length cs in - let res = Cstruct.create len in + and len = String.length buf in + let res = Bytes.create len in let rec mix i j = function | n when n = len -> (i, j, s) | n -> @@ -43,11 +43,11 @@ module ARC4 = struct let sj = s.(j) in s.(i) <- sj ; s.(j) <- si ; let k = s.((si + sj) land 0xff) in - Cstruct.(set_uint8 res n (k lxor get_uint8 cs n)); + Bytes.set_uint8 res n (k lxor string_get_uint8 buf n); mix i j (succ n) in let key' = mix i j 0 in - { key = key' ; message = res } + { key = key' ; message = Bytes.unsafe_to_string res } let decrypt = encrypt diff --git a/src/dune b/src/dune index fdf1b1bd..67287a97 100644 --- a/src/dune +++ b/src/dune @@ -1,7 +1,7 @@ (library (name mirage_crypto) (public_name mirage-crypto) - (libraries cstruct eqaf.cstruct) + (libraries eqaf) (private_modules aead chacha20 ccm cipher_block cipher_stream native poly1305 uncommon) (foreign_stubs diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index 2fd010c3..f0c6f0b9 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -29,33 +29,13 @@ module Uncommon : sig @raise Division_by_zero when [y < 1]. *) - (** Addons to the {!Cstruct} interface. *) - module Cs : sig - - val (<+>) : Cstruct.t -> Cstruct.t -> Cstruct.t - (** [<+>] is an alias for [Cstruct.append]. *) - - val xor_into : Cstruct.t -> Cstruct.t -> int -> unit - val xor : Cstruct.t -> Cstruct.t -> Cstruct.t - - (** {2 Private utilities} *) - - val clone : ?len:int -> Cstruct.t -> Cstruct.t - - val b : int -> Cstruct.t - - val of_bytes : int list -> Cstruct.t - - val split3 : Cstruct.t -> int -> int -> Cstruct.t * Cstruct.t * Cstruct.t - end - val imin : int -> int -> int val imax : int -> int -> int val iter2 : 'a -> 'a -> ('a -> unit) -> unit val iter3 : 'a -> 'a -> 'a -> ('a -> unit) -> unit val xor : string -> string -> string - val xor_into : string -> bytes -> int -> unit + val xor_into : string -> ?src_off:int -> bytes -> ?dst_off:int -> int -> unit val invalid_arg : ('a, Format.formatter, unit, unit, unit, 'b) format6 -> 'a val failwith : ('a, Format.formatter, unit, unit, unit, 'b) format6 -> 'a @@ -120,7 +100,7 @@ module type AEAD = sig type key (** The abstract type for the key. *) - val of_secret : Cstruct.t -> key + val of_secret : string -> key (** [of_secret secret] constructs the encryption key corresponding to [secret]. @@ -129,16 +109,16 @@ module type AEAD = sig (** {1 Authenticated encryption and decryption with inline tag} *) - val authenticate_encrypt : key:key -> nonce:Cstruct.t -> ?adata:Cstruct.t -> - Cstruct.t -> Cstruct.t + val authenticate_encrypt : key:key -> nonce:string -> ?adata:string -> + string -> string (** [authenticate_encrypt ~key ~nonce ~adata msg] encrypts [msg] with [key] and [nonce], and appends an authentication tag computed over the encrypted [msg], using [key], [nonce], and [adata]. @raise Invalid_argument if [nonce] is not of the right size. *) - val authenticate_decrypt : key:key -> nonce:Cstruct.t -> ?adata:Cstruct.t -> - Cstruct.t -> Cstruct.t option + val authenticate_decrypt : key:key -> nonce:string -> ?adata:string -> + string -> string option (** [authenticate_decrypt ~key ~nonce ~adata msg] splits [msg] into encrypted data and authentication tag, computes the authentication tag using [key], [nonce], and [adata], and decrypts the encrypted data. If the @@ -148,16 +128,16 @@ module type AEAD = sig (** {1 Authenticated encryption and decryption with tag provided separately} *) - val authenticate_encrypt_tag : key:key -> nonce:Cstruct.t -> - ?adata:Cstruct.t -> Cstruct.t -> Cstruct.t * Cstruct.t + val authenticate_encrypt_tag : key:key -> nonce:string -> + ?adata:string -> string -> string * string (** [authenticate_encrypt_tag ~key ~nonce ~adata msg] encrypts [msg] with [key] and [nonce]. The computed authentication tag is returned separately as second part of the tuple. @raise Invalid_argument if [nonce] is not of the right size. *) - val authenticate_decrypt_tag : key:key -> nonce:Cstruct.t -> - ?adata:Cstruct.t -> tag:Cstruct.t -> Cstruct.t -> Cstruct.t option + val authenticate_decrypt_tag : key:key -> nonce:string -> + ?adata:string -> tag:string -> string -> string option (** [authenticate_decrypt ~key ~nonce ~adata ~tag msg] computes the authentication tag using [key], [nonce], and [adata], and decrypts the encrypted data. If the authentication tags match, the decrypted data is @@ -184,9 +164,9 @@ module Cipher_block : sig (* type ekey *) (* type dkey *) - (* val of_secret : Cstruct.t -> ekey * dkey *) - (* val e_of_secret : Cstruct.t -> ekey *) - (* val d_of_secret : Cstruct.t -> dkey *) + (* val of_secret : string -> ekey * dkey *) + (* val e_of_secret : string -> ekey *) + (* val d_of_secret : string -> dkey *) (* val key : int array *) (* val block : int *) @@ -201,12 +181,12 @@ module Cipher_block : sig module type ECB = sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key val key_sizes : int array val block_size : int - val encrypt : key:key -> Cstruct.t -> Cstruct.t - val decrypt : key:key -> Cstruct.t -> Cstruct.t + val encrypt : key:key -> string -> string + val decrypt : key:key -> string -> string end (** {e Cipher-block chaining} mode. *) @@ -214,7 +194,7 @@ module Cipher_block : sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key (** Construct the encryption key corresponding to [secret]. @raise Invalid_argument if the length of [secret] is not in @@ -226,20 +206,20 @@ module Cipher_block : sig val block_size : int (** The size of a single block. *) - val encrypt : key:key -> iv:Cstruct.t -> Cstruct.t -> Cstruct.t + val encrypt : key:key -> iv:string -> string -> string (** [encrypt ~key ~iv msg] is [msg] encrypted under [key], using [iv] as the CBC initialization vector. @raise Invalid_argument if [iv] is not [block_size], or [msg] is not [k * block_size] long. *) - val decrypt : key:key -> iv:Cstruct.t -> Cstruct.t -> Cstruct.t + val decrypt : key:key -> iv:string -> string -> string (** [decrypt ~key ~iv msg] is the inverse of [encrypt]. @raise Invalid_argument if [iv] is not [block_size], or [msg] is not [k * block_size] long. *) - val next_iv : iv:Cstruct.t -> Cstruct.t -> Cstruct.t + val next_iv : iv:string -> string -> string (** [next_iv ~iv ciphertext] is the first [iv] {e following} the encryption that used [iv] to produce [ciphertext]. @@ -261,7 +241,7 @@ module Cipher_block : sig type key - val of_secret : Cstruct.t -> key + val of_secret : string -> key (** Construct the encryption key corresponding to [secret]. @raise Invalid_argument if the length of [secret] is not in @@ -275,7 +255,7 @@ module Cipher_block : sig type ctr - val stream : key:key -> ctr:ctr -> int -> Cstruct.t + val stream : key:key -> ctr:ctr -> int -> string (** [stream ~key ~ctr n] is the raw keystream. Keystream is the concatenation of successive encrypted counter states. @@ -291,17 +271,17 @@ module Cipher_block : sig In other words, it is possible to restart a keystream at [block_size] boundaries by manipulating the counter. *) - val encrypt : key:key -> ctr:ctr -> Cstruct.t -> Cstruct.t + val encrypt : key:key -> ctr:ctr -> string -> string (** [encrypt ~key ~ctr msg] is [stream ~key ~ctr ~off (len msg) lxor msg]. *) - val decrypt : key:key -> ctr:ctr -> Cstruct.t -> Cstruct.t + val decrypt : key:key -> ctr:ctr -> string -> string (** [decrypt] is [encrypt]. *) val add_ctr : ctr -> int64 -> ctr (** [add_ctr ctr n] adds [n] to [ctr]. *) - val next_ctr : ctr:ctr -> Cstruct.t -> ctr + val next_ctr : ctr:ctr -> string -> ctr (** [next_ctr ~ctr msg] is the state of the counter after encrypting or decrypting [msg] with the counter [ctr]. @@ -316,8 +296,8 @@ module Cipher_block : sig *) - val ctr_of_cstruct : Cstruct.t -> ctr - (** [ctr_of_cstruct cs] converts the value of [cs] into a counter. *) + val ctr_of_octets : string -> ctr + (** [ctr_of_octets buf] converts the value of [buf] into a counter. *) end (** {e Galois/Counter Mode}. *) @@ -392,10 +372,10 @@ module Cipher_stream : sig (** General stream cipher type. *) module type S = sig type key - type result = { message : Cstruct.t ; key : key } - val of_secret : Cstruct.t -> key - val encrypt : key:key -> Cstruct.t -> result - val decrypt : key:key -> Cstruct.t -> result + type result = { message : string ; key : key } + val of_secret : string -> key + val encrypt : key:key -> string -> result + val decrypt : key:key -> string -> result end (** {e Alleged Rivest Cipher 4}. *) diff --git a/src/native.ml b/src/native.ml index 1cc91aa4..911a050f 100644 --- a/src/native.ml +++ b/src/native.ml @@ -1,65 +1,47 @@ -open Stdlib.Bigarray - -let buffer = Array1.create char c_layout - - -type buffer = (char, int8_unsigned_elt, c_layout) Array1.t - -type off = int -type size = int -type secret = buffer -type key = buffer -type ctx = bytes - - module AES = struct - external enc : buffer -> off -> buffer -> off -> key -> int -> size -> unit = "mc_aes_enc_bc" "mc_aes_enc" [@@noalloc] - external dec : buffer -> off -> buffer -> off -> key -> int -> size -> unit = "mc_aes_dec_bc" "mc_aes_dec" [@@noalloc] - external derive_e : secret -> off -> key -> int -> unit = "mc_aes_derive_e_key" [@@noalloc] - external derive_d : secret -> off -> key -> int -> key option -> unit = "mc_aes_derive_d_key" [@@noalloc] + external enc : string -> int -> bytes -> int -> string -> int -> int -> unit = "mc_aes_enc_bc" "mc_aes_enc" [@@noalloc] + external dec : string -> int -> bytes -> int -> string -> int -> int -> unit = "mc_aes_dec_bc" "mc_aes_dec" [@@noalloc] + external derive_e : string -> bytes -> int -> unit = "mc_aes_derive_e_key" [@@noalloc] + external derive_d : string -> bytes -> int -> string option -> unit = "mc_aes_derive_d_key" [@@noalloc] external rk_s : int -> int = "mc_aes_rk_size" [@@noalloc] external mode : unit -> int = "mc_aes_mode" [@@noalloc] end module DES = struct - external ddes : buffer -> off -> buffer -> off -> int -> unit = "mc_des_ddes" [@@noalloc] - external des3key : secret -> off -> int -> unit = "mc_des_des3key" [@@noalloc] - external cp3key : key -> unit = "mc_des_cp3key" [@@noalloc] - external use3key : key -> unit = "mc_des_use3key" [@@noalloc] + external ddes : string -> int -> bytes -> int -> int -> unit = "mc_des_ddes" [@@noalloc] + external des3key : bytes -> int -> unit = "mc_des_des3key" [@@noalloc] + external cp3key : bytes -> unit = "mc_des_cp3key" [@@noalloc] + external use3key : string -> unit = "mc_des_use3key" [@@noalloc] external k_s : unit -> int = "mc_des_key_size" [@@noalloc] end module Chacha = struct - external round : int -> bytes -> bytes -> off -> unit = "mc_chacha_round" [@@noalloc] + external round : int -> bytes -> bytes -> int -> unit = "mc_chacha_round" [@@noalloc] end module Poly1305 = struct - external init : ctx -> string -> unit = "mc_poly1305_init" [@@noalloc] - external update : ctx -> string -> size -> unit = "mc_poly1305_update" [@@noalloc] - external finalize : ctx -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc] + external init : bytes -> string -> unit = "mc_poly1305_init" [@@noalloc] + external update : bytes -> string -> int -> unit = "mc_poly1305_update" [@@noalloc] + external finalize : bytes -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc] external ctx_size : unit -> int = "mc_poly1305_ctx_size" [@@noalloc] external mac_size : unit -> int = "mc_poly1305_mac_size" [@@noalloc] end module GHASH = struct external keysize : unit -> int = "mc_ghash_key_size" [@@noalloc] - external keyinit : buffer -> off -> bytes -> unit = "mc_ghash_init_key" [@@noalloc] - external ghash : bytes -> bytes -> buffer -> off -> size -> unit = "mc_ghash" [@@noalloc] + external keyinit : string -> bytes -> unit = "mc_ghash_init_key" [@@noalloc] + external ghash : string -> bytes -> string -> int -> unit = "mc_ghash" [@@noalloc] external mode : unit -> int = "mc_ghash_mode" [@@noalloc] end (* XXX TODO * Unsolved: bounds-checked XORs are slowing things down considerably... *) -external xor_into : buffer -> off -> buffer -> off -> size -> unit = "mc_xor_into" [@@noalloc] - -external xor_into_bytes : string -> off -> bytes -> off -> size -> unit = "mc_xor_into_bytes" [@@noalloc] - -external count8be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_8_be" [@@noalloc] -external count16be : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be" [@@noalloc] -external count16be4 : bytes -> buffer -> off -> blocks:size -> unit = "mc_count_16_be_4" [@@noalloc] +external xor_into_bytes : string -> int -> bytes -> int -> int -> unit = "mc_xor_into_bytes" [@@noalloc] -external blit : buffer -> off -> buffer -> off -> size -> unit = "caml_blit_bigstring_to_bigstring" [@@noalloc] +external count8be : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] +external count16be : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] +external count16be4 : bytes -> bytes -> int -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] external misc_mode : unit -> int = "mc_misc_mode" [@@noalloc] diff --git a/src/native/aes_aesni.c b/src/native/aes_aesni.c index e36f4896..fe2eb7da 100644 --- a/src/native/aes_aesni.c +++ b/src/native/aes_aesni.c @@ -17,14 +17,8 @@ #define _S_1111 0x55 #define _S_0000 0x00 -/* - * RKs are currently aligned from the C side on access. Would be better to - * allocate and pass them in pre-aligned. - * - * XXX Get rid of the correction here. - */ static int _mc_aesni_rk_size (uint8_t rounds) { - return (rounds + 1) * 16 + 15; + return (rounds + 1) * 16; } #if defined(__x86_64__) || defined(_WIN64) @@ -61,9 +55,10 @@ static inline void __pack (__m128i *o1, __m128i *o2, __m128i r1, __m128i r2, __m #endif static inline void _mc_aesni_derive_e_key (const uint8_t *key, uint8_t *rk0, uint8_t rounds) { - - __m128i *rk = __rk (rk0); + __m128i schedule[15 + 1]; + __m128i *rk = __rk (schedule); __m128i temp1, temp2; + int i; switch (rounds) { case 10: @@ -142,24 +137,26 @@ static inline void _mc_aesni_derive_e_key (const uint8_t *key, uint8_t *rk0, uin default: ; + }; + + for (i = 0; i <= rounds; i++) { + _mm_storeu_si128((__m128i*) rk0 + i, rk[i]); } } -static inline void _mc_aesni_invert_e_key (const uint8_t *rk0, uint8_t *kr0, uint8_t rounds) { +static inline void _mc_aesni_invert_e_key (const uint8_t *rk1, uint8_t *kr0, uint8_t rounds) { - __m128i *rk1 = __rk (rk0), - *kr = __rk (kr0), - rk[15]; + __m128i rk[15]; for (uint8_t i = 0; i <= rounds; i++) - rk[i] = rk1[i]; + rk[i] = _mm_loadu_si128 ((__m128i*) rk1 + i); - kr[0] = rk[rounds]; + _mm_storeu_si128((__m128i*) kr0 + 0, rk[rounds]); for (uint8_t i = 1; i < rounds; i++) - kr[i] = _mm_aesimc_si128 (rk[rounds - i]); + _mm_storeu_si128((__m128i*) kr0 + i, _mm_aesimc_si128 (rk[rounds - i])); - kr[rounds] = rk[0]; + _mm_storeu_si128((__m128i*) kr0 + rounds, rk[0]); } static void _mc_aesni_derive_d_key (const uint8_t *key, uint8_t *kr, uint8_t rounds, uint8_t *rk) { @@ -174,28 +171,36 @@ static void _mc_aesni_derive_d_key (const uint8_t *key, uint8_t *kr, uint8_t rou static inline void _mc_aesni_enc (const uint8_t src[16], uint8_t dst[16], const uint8_t *rk0, uint8_t rounds) { __m128i r = _mm_loadu_si128 ((__m128i*) src), - *rk = __rk (rk0); + rk; - r = _mm_xor_si128 (r, rk[0]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + 0); + r = _mm_xor_si128 (r, rk); - for (uint8_t i = 1; i < rounds; i++) - r = _mm_aesenc_si128 (r, rk[i]); + for (uint8_t i = 1; i < rounds; i++) { + rk = _mm_loadu_si128 ((__m128i*) rk0 + i); + r = _mm_aesenc_si128 (r, rk); + } - r = _mm_aesenclast_si128 (r, rk[rounds]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + rounds); + r = _mm_aesenclast_si128 (r, rk); _mm_storeu_si128 ((__m128i*) dst, r); } static inline void _mc_aesni_dec (const uint8_t src[16], uint8_t dst[16], const uint8_t *rk0, uint8_t rounds) { __m128i r = _mm_loadu_si128 ((__m128i*) src), - *rk = __rk (rk0); + rk; - r = _mm_xor_si128 (r, rk[0]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + 0); + r = _mm_xor_si128 (r, rk); - for (uint8_t i = 1; i < rounds; i++) - r = _mm_aesdec_si128 (r, rk[i]); + for (uint8_t i = 1; i < rounds; i++) { + rk = _mm_loadu_si128 ((__m128i*) rk0 + i); + r = _mm_aesdec_si128 (r, rk); + } - r = _mm_aesdeclast_si128 (r, rk[rounds]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + rounds); + r = _mm_aesdeclast_si128 (r, rk); _mm_storeu_si128 ((__m128i*) dst, r); } @@ -203,7 +208,7 @@ static inline void _mc_aesni_enc8 (const uint8_t src[128], uint8_t dst[128], con __m128i *in = (__m128i*) src, *out = (__m128i*) dst, - *rk = __rk (rk0); + rk; __m128i r0 = _mm_loadu_si128 (in ), r1 = _mm_loadu_si128 (in + 1), @@ -214,34 +219,37 @@ static inline void _mc_aesni_enc8 (const uint8_t src[128], uint8_t dst[128], con r6 = _mm_loadu_si128 (in + 6), r7 = _mm_loadu_si128 (in + 7); - r0 = _mm_xor_si128 (r0, rk[0]); - r1 = _mm_xor_si128 (r1, rk[0]); - r2 = _mm_xor_si128 (r2, rk[0]); - r3 = _mm_xor_si128 (r3, rk[0]); - r4 = _mm_xor_si128 (r4, rk[0]); - r5 = _mm_xor_si128 (r5, rk[0]); - r6 = _mm_xor_si128 (r6, rk[0]); - r7 = _mm_xor_si128 (r7, rk[0]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + 0); + r0 = _mm_xor_si128 (r0, rk); + r1 = _mm_xor_si128 (r1, rk); + r2 = _mm_xor_si128 (r2, rk); + r3 = _mm_xor_si128 (r3, rk); + r4 = _mm_xor_si128 (r4, rk); + r5 = _mm_xor_si128 (r5, rk); + r6 = _mm_xor_si128 (r6, rk); + r7 = _mm_xor_si128 (r7, rk); for (uint8_t i = 1; i < rounds; i++) { - r0 = _mm_aesenc_si128 (r0, rk[i]); - r1 = _mm_aesenc_si128 (r1, rk[i]); - r2 = _mm_aesenc_si128 (r2, rk[i]); - r3 = _mm_aesenc_si128 (r3, rk[i]); - r4 = _mm_aesenc_si128 (r4, rk[i]); - r5 = _mm_aesenc_si128 (r5, rk[i]); - r6 = _mm_aesenc_si128 (r6, rk[i]); - r7 = _mm_aesenc_si128 (r7, rk[i]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + i); + r0 = _mm_aesenc_si128 (r0, rk); + r1 = _mm_aesenc_si128 (r1, rk); + r2 = _mm_aesenc_si128 (r2, rk); + r3 = _mm_aesenc_si128 (r3, rk); + r4 = _mm_aesenc_si128 (r4, rk); + r5 = _mm_aesenc_si128 (r5, rk); + r6 = _mm_aesenc_si128 (r6, rk); + r7 = _mm_aesenc_si128 (r7, rk); } - r0 = _mm_aesenclast_si128 (r0, rk[rounds]); - r1 = _mm_aesenclast_si128 (r1, rk[rounds]); - r2 = _mm_aesenclast_si128 (r2, rk[rounds]); - r3 = _mm_aesenclast_si128 (r3, rk[rounds]); - r4 = _mm_aesenclast_si128 (r4, rk[rounds]); - r5 = _mm_aesenclast_si128 (r5, rk[rounds]); - r6 = _mm_aesenclast_si128 (r6, rk[rounds]); - r7 = _mm_aesenclast_si128 (r7, rk[rounds]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + rounds); + r0 = _mm_aesenclast_si128 (r0, rk); + r1 = _mm_aesenclast_si128 (r1, rk); + r2 = _mm_aesenclast_si128 (r2, rk); + r3 = _mm_aesenclast_si128 (r3, rk); + r4 = _mm_aesenclast_si128 (r4, rk); + r5 = _mm_aesenclast_si128 (r5, rk); + r6 = _mm_aesenclast_si128 (r6, rk); + r7 = _mm_aesenclast_si128 (r7, rk); _mm_storeu_si128 (out , r0); _mm_storeu_si128 (out + 1, r1); @@ -257,7 +265,7 @@ static inline void _mc_aesni_dec8 (const uint8_t src[128], uint8_t dst[128], con __m128i *in = (__m128i*) src, *out = (__m128i*) dst, - *rk = __rk (rk0); + rk; __m128i r0 = _mm_loadu_si128 (in ), r1 = _mm_loadu_si128 (in + 1), @@ -268,34 +276,37 @@ static inline void _mc_aesni_dec8 (const uint8_t src[128], uint8_t dst[128], con r6 = _mm_loadu_si128 (in + 6), r7 = _mm_loadu_si128 (in + 7); - r0 = _mm_xor_si128 (r0, rk[0]); - r1 = _mm_xor_si128 (r1, rk[0]); - r2 = _mm_xor_si128 (r2, rk[0]); - r3 = _mm_xor_si128 (r3, rk[0]); - r4 = _mm_xor_si128 (r4, rk[0]); - r5 = _mm_xor_si128 (r5, rk[0]); - r6 = _mm_xor_si128 (r6, rk[0]); - r7 = _mm_xor_si128 (r7, rk[0]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + 0); + r0 = _mm_xor_si128 (r0, rk); + r1 = _mm_xor_si128 (r1, rk); + r2 = _mm_xor_si128 (r2, rk); + r3 = _mm_xor_si128 (r3, rk); + r4 = _mm_xor_si128 (r4, rk); + r5 = _mm_xor_si128 (r5, rk); + r6 = _mm_xor_si128 (r6, rk); + r7 = _mm_xor_si128 (r7, rk); for (uint8_t i = 1; i < rounds; i++) { - r0 = _mm_aesdec_si128 (r0, rk[i]); - r1 = _mm_aesdec_si128 (r1, rk[i]); - r2 = _mm_aesdec_si128 (r2, rk[i]); - r3 = _mm_aesdec_si128 (r3, rk[i]); - r4 = _mm_aesdec_si128 (r4, rk[i]); - r5 = _mm_aesdec_si128 (r5, rk[i]); - r6 = _mm_aesdec_si128 (r6, rk[i]); - r7 = _mm_aesdec_si128 (r7, rk[i]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + i); + r0 = _mm_aesdec_si128 (r0, rk); + r1 = _mm_aesdec_si128 (r1, rk); + r2 = _mm_aesdec_si128 (r2, rk); + r3 = _mm_aesdec_si128 (r3, rk); + r4 = _mm_aesdec_si128 (r4, rk); + r5 = _mm_aesdec_si128 (r5, rk); + r6 = _mm_aesdec_si128 (r6, rk); + r7 = _mm_aesdec_si128 (r7, rk); } - r0 = _mm_aesdeclast_si128 (r0, rk[rounds]); - r1 = _mm_aesdeclast_si128 (r1, rk[rounds]); - r2 = _mm_aesdeclast_si128 (r2, rk[rounds]); - r3 = _mm_aesdeclast_si128 (r3, rk[rounds]); - r4 = _mm_aesdeclast_si128 (r4, rk[rounds]); - r5 = _mm_aesdeclast_si128 (r5, rk[rounds]); - r6 = _mm_aesdeclast_si128 (r6, rk[rounds]); - r7 = _mm_aesdeclast_si128 (r7, rk[rounds]); + rk = _mm_loadu_si128 ((__m128i*) rk0 + rounds); + r0 = _mm_aesdeclast_si128 (r0, rk); + r1 = _mm_aesdeclast_si128 (r1, rk); + r2 = _mm_aesdeclast_si128 (r2, rk); + r3 = _mm_aesdeclast_si128 (r3, rk); + r4 = _mm_aesdeclast_si128 (r4, rk); + r5 = _mm_aesdeclast_si128 (r5, rk); + r6 = _mm_aesdeclast_si128 (r6, rk); + r7 = _mm_aesdeclast_si128 (r7, rk); _mm_storeu_si128 (out , r0); _mm_storeu_si128 (out + 1, r1); @@ -361,23 +372,23 @@ mc_aes_rk_size (value rounds) { } CAMLprim value -mc_aes_derive_e_key (value key, value off1, value rk, value rounds) { +mc_aes_derive_e_key (value key, value rk, value rounds) { _mc_switch_accel(aesni, - mc_aes_derive_e_key_generic(key, off1, rk, rounds), - _mc_aesni_derive_e_key (_ba_uint8_off (key, off1), - _ba_uint8 (rk), + mc_aes_derive_e_key_generic(key, rk, rounds), + _mc_aesni_derive_e_key (_st_uint8 (key), + _bp_uint8 (rk), Int_val (rounds))) return Val_unit; } CAMLprim value -mc_aes_derive_d_key (value key, value off1, value kr, value rounds, value rk) { +mc_aes_derive_d_key (value key, value kr, value rounds, value rk) { _mc_switch_accel(aesni, - mc_aes_derive_d_key_generic(key, off1, kr, rounds, rk), - _mc_aesni_derive_d_key (_ba_uint8_off (key, off1), - _ba_uint8 (kr), + mc_aes_derive_d_key_generic(key, kr, rounds, rk), + _mc_aesni_derive_d_key (_st_uint8 (key), + _bp_uint8 (kr), Int_val (rounds), - Is_block(rk) ? _ba_uint8(Field(rk, 0)) : 0)) + Is_block(rk) ? _bp_uint8(Field(rk, 0)) : 0)) return Val_unit; } @@ -385,9 +396,9 @@ CAMLprim value mc_aes_enc (value src, value off1, value dst, value off2, value rk, value rounds, value blocks) { _mc_switch_accel(aesni, mc_aes_enc_generic(src, off1, dst, off2, rk, rounds, blocks), - _mc_aesni_enc_blocks ( _ba_uint8_off (src, off1), - _ba_uint8_off (dst, off2), - _ba_uint8 (rk), + _mc_aesni_enc_blocks ( _st_uint8_off (src, off1), + _bp_uint8_off (dst, off2), + _st_uint8 (rk), Int_val (rounds), Int_val (blocks) )) return Val_unit; @@ -397,9 +408,9 @@ CAMLprim value mc_aes_dec (value src, value off1, value dst, value off2, value rk, value rounds, value blocks) { _mc_switch_accel(aesni, mc_aes_dec_generic(src, off1, dst, off2, rk, rounds, blocks), - _mc_aesni_dec_blocks ( _ba_uint8_off (src, off1), - _ba_uint8_off (dst, off2), - _ba_uint8 (rk), + _mc_aesni_dec_blocks ( _st_uint8_off (src, off1), + _bp_uint8_off (dst, off2), + _st_uint8 (rk), Int_val (rounds), Int_val (blocks) )) return Val_unit; diff --git a/src/native/aes_generic.c b/src/native/aes_generic.c index da1dece3..1f4aa0bf 100644 --- a/src/native/aes_generic.c +++ b/src/native/aes_generic.c @@ -1232,26 +1232,26 @@ mc_aes_rk_size_generic (value rounds) { } CAMLprim value -mc_aes_derive_e_key_generic (value key, value off1, value rk, value rounds) { - mc_rijndaelSetupEncrypt (_ba_uint32 (rk), - _ba_uint8_off (key, off1), +mc_aes_derive_e_key_generic (value key, value rk, value rounds) { + mc_rijndaelSetupEncrypt (_bp_uint32 (rk), + _st_uint8 (key), keybits_of_r (Int_val (rounds))); return Val_unit; } CAMLprim value -mc_aes_derive_d_key_generic (value key, value off1, value kr, value rounds, value __unused (rk)) { - mc_rijndaelSetupDecrypt (_ba_uint32 (kr), - _ba_uint8_off (key, off1), +mc_aes_derive_d_key_generic (value key, value kr, value rounds, value __unused (rk)) { + mc_rijndaelSetupDecrypt (_bp_uint32 (kr), + _st_uint8 (key), keybits_of_r (Int_val (rounds))); return Val_unit; } CAMLprim value mc_aes_enc_generic (value src, value off1, value dst, value off2, value rk, value rounds, value blocks) { - _mc_aes_enc_blocks ( _ba_uint8_off (src, off1), - _ba_uint8_off (dst, off2), - _ba_uint32 (rk), + _mc_aes_enc_blocks ( _st_uint8_off (src, off1), + _bp_uint8_off (dst, off2), + _st_uint32 (rk), Int_val (rounds), Int_val (blocks) ); return Val_unit; @@ -1259,9 +1259,9 @@ mc_aes_enc_generic (value src, value off1, value dst, value off2, value rk, valu CAMLprim value mc_aes_dec_generic (value src, value off1, value dst, value off2, value rk, value rounds, value blocks) { - _mc_aes_dec_blocks ( _ba_uint8_off (src, off1), - _ba_uint8_off (dst, off2), - _ba_uint32 (rk), + _mc_aes_dec_blocks ( _st_uint8_off(src, off1), + _bp_uint8_off(dst, off2), + _st_uint32 (rk), Int_val (rounds), Int_val (blocks) ); return Val_unit; diff --git a/src/native/des_generic.c b/src/native/des_generic.c index 99332a03..c743cc56 100644 --- a/src/native/des_generic.c +++ b/src/native/des_generic.c @@ -18,7 +18,7 @@ #include "mirage_crypto.h" #include "des_generic.h" -static void scrunch(unsigned char *, unsigned long *); +static void scrunch(const unsigned char *, unsigned long *); static void unscrun(unsigned long *, unsigned char *); static void desfunc(unsigned long *, unsigned long *); static void cookey(unsigned long *); @@ -145,7 +145,7 @@ void mc_des(unsigned char inblock[8], unsigned char outblock[8]) } -static void scrunch(unsigned char *outof, unsigned long *into) +static void scrunch(const unsigned char *outof, unsigned long *into) { *into = (*outof++ & 0xffL) << 24; *into |= (*outof++ & 0xffL) << 16; @@ -404,7 +404,7 @@ void mc_des2key(unsigned char hexkey[16], short mode) /* stomps on Kn3 too */ return; } -void mc_Ddes(unsigned char from[8], unsigned char into[8]) +void mc_Ddes(const unsigned char from[8], unsigned char into[8]) { unsigned long work[2]; @@ -657,7 +657,7 @@ void mc_make3key(char *aptr /* NULL-terminated */, unsigned char kptr[24]) /* OCaml front-end */ -static inline void _mc_ddes (unsigned char *src, unsigned char *dst, unsigned int blocks) { +static inline void _mc_ddes (const unsigned char *src, unsigned char *dst, unsigned int blocks) { while (blocks --) { mc_Ddes (src, dst); src += 8 ; dst += 8; @@ -670,25 +670,25 @@ mc_des_key_size (__unit ()) { } CAMLprim value -mc_des_des3key (value key, value off, value direction) { - mc_des3key (_ba_uint8_off (key, off), Int_val (direction)); +mc_des_des3key (value key, value direction) { + mc_des3key (_bp_uint8 (key), Int_val (direction)); return Val_unit; } CAMLprim value mc_des_cp3key (value dst) { - mc_cp3key ((unsigned long *) _ba_uint8 (dst)); + mc_cp3key ((unsigned long *) _bp_uint8 (dst)); return Val_unit; } CAMLprim value mc_des_use3key (value src) { - mc_use3key ((unsigned long *) _ba_uint8 (src)); + mc_use3key ((unsigned long *) _bp_uint8 (src)); return Val_unit; } CAMLprim value mc_des_ddes (value src, value off1, value dst, value off2, value blocks) { - _mc_ddes (_ba_uint8_off (src, off1), _ba_uint8_off (dst, off2), Int_val (blocks)); + _mc_ddes (_st_uint8_off (src, off1), _bp_uint8_off (dst, off2), Int_val (blocks)); return Val_unit; } diff --git a/src/native/des_generic.h b/src/native/des_generic.h index df88f73e..4c3fe0ed 100644 --- a/src/native/des_generic.h +++ b/src/native/des_generic.h @@ -72,7 +72,7 @@ extern void mc_des2key(unsigned char [16], short); * NOTE: this clobbers all three key registers! */ -extern void mc_Ddes(unsigned char [8], unsigned char [8]); +extern void mc_Ddes(const unsigned char [8], unsigned char [8]); /* from[8] to[8] * Encrypts/Decrypts (according to the keyS currently loaded in the * internal key registerS) one block of eight bytes at address 'from' diff --git a/src/native/ghash_ctmul.c b/src/native/ghash_ctmul.c index e18edf3c..7788fd05 100644 --- a/src/native/ghash_ctmul.c +++ b/src/native/ghash_ctmul.c @@ -284,14 +284,14 @@ static inline void __copy (uint64_t key[2], uint32_t m[4]) { m[3] = key[1] >> 32; } -CAMLprim value mc_ghash_init_key_generic (value key, value off, value m) { +CAMLprim value mc_ghash_init_key_generic (value key, value m) { //push key at off into m - __copy ((uint64_t *) _ba_uint8_off(key, off), (uint32_t *) m); + __copy ((uint64_t *) _st_uint8(key), (uint32_t *) m); return Val_unit; } -CAMLprim value mc_ghash_generic (value m, value hash, value src, value off, value len) { - br_ghash_ctmul(Bp_val(hash), Bp_val(m), _ba_uint8_off(src, off), Int_val(len)); +CAMLprim value mc_ghash_generic (value m, value hash, value src, value len) { + br_ghash_ctmul(Bp_val(hash), Bp_val(m), _st_uint8(src), Int_val(len)); return Val_unit; } diff --git a/src/native/ghash_generic.c b/src/native/ghash_generic.c index cdb3faf1..2cc49532 100644 --- a/src/native/ghash_generic.c +++ b/src/native/ghash_generic.c @@ -95,15 +95,15 @@ CAMLprim value mc_ghash_key_size_generic (__unit ()) { return Val_int (sizeof (__uint128_t) * __t_size); } -CAMLprim value mc_ghash_init_key_generic (value key, value off, value m) { - __derive ((uint64_t *) _ba_uint8_off (key, off), (__uint128_t *) Bp_val (m)); +CAMLprim value mc_ghash_init_key_generic (value key, value m) { + __derive ((uint64_t *) _st_uint8 (key), (__uint128_t *) Bp_val (m)); return Val_unit; } CAMLprim value -mc_ghash_generic (value m, value hash, value src, value off, value len) { +mc_ghash_generic (value m, value hash, value src, value len) { __ghash ((__uint128_t *) Bp_val (m), (uint64_t *) Bp_val (hash), - _ba_uint8_off (src, off), Int_val (len) ); + _st_uint8 (src), Int_val (len) ); return Val_unit; } diff --git a/src/native/ghash_pclmul.c b/src/native/ghash_pclmul.c index 0b28d174..58ca02ea 100644 --- a/src/native/ghash_pclmul.c +++ b/src/native/ghash_pclmul.c @@ -196,19 +196,19 @@ CAMLprim value mc_ghash_key_size (__unit ()) { return s; } -CAMLprim value mc_ghash_init_key (value key, value off, value m) { +CAMLprim value mc_ghash_init_key (value key, value m) { _mc_switch_accel(pclmul, - mc_ghash_init_key_generic(key, off, m), - __derive ((__m128i *) _ba_uint8_off (key, off), (__m128i *) Bp_val (m))) + mc_ghash_init_key_generic(key, m), + __derive ((__m128i *) _st_uint8 (key), (__m128i *) Bp_val (m))) return Val_unit; } CAMLprim value -mc_ghash (value k, value hash, value src, value off, value len) { +mc_ghash (value k, value hash, value src, value len) { _mc_switch_accel(pclmul, - mc_ghash_generic(k, hash, src, off, len), + mc_ghash_generic(k, hash, src, len), __ghash ( (__m128i *) Bp_val (k), (__m128i *) Bp_val (hash), - (__m128i *) _ba_uint8_off (src, off), Int_val (len) )) + (__m128i *) _st_uint8 (src), Int_val (len) )) return Val_unit; } diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 7c70d4c2..6608a1b1 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -6,7 +6,6 @@ #include "bitfn.h" #include -#include #ifdef ACCELERATE # ifdef _MSC_VER @@ -72,12 +71,8 @@ extern struct _mc_cpu_features mc_detected_cpu_features; #define __unit() value __unused(_) #define _st_uint8(v) ((const uint8_t*) (String_val(v))) - -#define _ba_uint8_off(ba, off) ((uint8_t*) Caml_ba_data_val (ba) + Long_val (off)) -#define _ba_uint32_off(ba, off) ((uint32_t*) Caml_ba_data_val (ba) + Long_val (off)) - -#define _ba_uint8(ba) ((uint8_t*) Caml_ba_data_val (ba)) -#define _ba_uint32(ba) ((uint32_t*) Caml_ba_data_val (ba)) +#define _st_uint32(v) ((const uint32_t*) (String_val(v))) +#define _st_uint8_off(v, off) ((const uint8_t*)(String_val(v) + Long_val(off))) #define _bp_uint8_off(bp, off) ((uint8_t *) Bp_val (bp) + Long_val (off)) #define _bp_uint8(bp) ((uint8_t *) Bp_val (bp)) @@ -94,10 +89,10 @@ extern struct _mc_cpu_features mc_detected_cpu_features; CAMLprim value mc_aes_rk_size_generic (value rounds); CAMLprim value -mc_aes_derive_e_key_generic (value key, value off1, value rk, value rounds); +mc_aes_derive_e_key_generic (value key, value rk, value rounds); CAMLprim value -mc_aes_derive_d_key_generic (value key, value off1, value kr, value rounds, value __unused (rk)); +mc_aes_derive_d_key_generic (value key, value kr, value rounds, value __unused (rk)); CAMLprim value mc_aes_enc_generic (value src, value off1, value dst, value off2, value rk, value rounds, value blocks); @@ -107,10 +102,10 @@ mc_aes_dec_generic (value src, value off1, value dst, value off2, value rk, valu CAMLprim value mc_ghash_key_size_generic (__unit ()); -CAMLprim value mc_ghash_init_key_generic (value key, value off, value m); +CAMLprim value mc_ghash_init_key_generic (value key, value m); CAMLprim value -mc_ghash_generic (value m, value hash, value src, value off, value len); +mc_ghash_generic (value m, value hash, value src, value len); CAMLprim value mc_xor_into_generic (value b1, value off1, value b2, value off2, value n); diff --git a/src/native/misc.c b/src/native/misc.c index 97083d42..dea76e18 100644 --- a/src/native/misc.c +++ b/src/native/misc.c @@ -53,22 +53,16 @@ static inline void _mc_count_16_be_4 (uint64_t *init, uint64_t *dst, size_t bloc } } -CAMLprim value -mc_xor_into_generic (value b1, value off1, value b2, value off2, value n) { - xor_into (_ba_uint8_off (b1, off1), _ba_uint8_off (b2, off2), Int_val (n)); - return Val_unit; -} - CAMLprim value mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n) { - xor_into (_st_uint8 (b1) + Long_val(off1), Bytes_val (b2) + Long_val(off2), Int_val (n)); + xor_into (_st_uint8_off (b1, off1), _bp_uint8_off (b2, off2), Int_val (n)); return Val_unit; } #define __export_counter(name, f) \ CAMLprim value name (value ctr, value dst, value off, value blocks) { \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _ba_uint8_off (dst, off), Long_val (blocks) ); \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) ); \ return Val_unit; \ } diff --git a/src/native/misc_sse.c b/src/native/misc_sse.c index fe322745..a5a068c5 100644 --- a/src/native/misc_sse.c +++ b/src/native/misc_sse.c @@ -39,19 +39,11 @@ static inline void _mc_count_16_be_4 (uint64_t *init, uint64_t *dst, size_t bloc #endif /* __mc_ACCELERATE__ */ -CAMLprim value -mc_xor_into (value b1, value off1, value b2, value off2, value n) { - _mc_switch_accel(ssse3, - mc_xor_into_generic(b1, off1, b2, off2, n), - xor_into (_ba_uint8_off (b1, off1), _ba_uint8_off (b2, off2), Int_val (n))) - return Val_unit; -} - CAMLprim value mc_xor_into_bytes (value b1, value off1, value b2, value off2, value n) { _mc_switch_accel(ssse3, mc_xor_into_bytes_generic(b1, off1, b2, off2, n), - xor_into (_st_uint8 (b1) + Long_val(off1), Bytes_val (b2) + Long_val(off2), Int_val (n))) + xor_into (_st_uint8_off (b1, off1), _bp_uint8_off (b2, off2), Int_val (n))) return Val_unit; } @@ -60,7 +52,7 @@ mc_xor_into_bytes (value b1, value off1, value b2, value off2, value n) { _mc_switch_accel(ssse3, \ name##_generic (ctr, dst, off, blocks), \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _ba_uint8_off (dst, off), Long_val (blocks) )) \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) )) \ return Val_unit; \ } diff --git a/src/poly1305.ml b/src/poly1305.ml index 08df20ab..8d4caef8 100644 --- a/src/poly1305.ml +++ b/src/poly1305.ml @@ -22,7 +22,7 @@ module It : S = struct module P = Native.Poly1305 let mac_size = P.mac_size () - type t = Native.ctx + type t = bytes let dup = Bytes.copy diff --git a/src/uncommon.ml b/src/uncommon.ml index ca23e16b..0f282e53 100644 --- a/src/uncommon.ml +++ b/src/uncommon.ml @@ -18,8 +18,8 @@ type 'a iter = ('a -> unit) -> unit let iter2 a b f = f a; f b let iter3 a b c f = f a; f b; f c -let xor_into src dst n = - Native.xor_into_bytes src 0 dst 0 n +let xor_into src ?(src_off = 0) dst ?(dst_off = 0) n = + Native.xor_into_bytes src src_off dst dst_off n let xor a b = assert (String.length a = String.length b); @@ -27,51 +27,6 @@ let xor a b = xor_into a b' (Bytes.length b'); Bytes.unsafe_to_string b' -module Cs = struct - - open Cstruct - - let (<+>) = append - - let clone ?len cs = - let len = match len with None -> cs.len | Some x -> x in - let cs' = create_unsafe len in - ( blit cs 0 cs' 0 len ; cs' ) - - let xor_into src dst n = - if n > imin (length src) (length dst) then - invalid_arg "Uncommon.Cs.xor_into: buffers to small (need %d)" n - else Native.xor_into src.buffer src.off dst.buffer dst.off n - - let xor cs1 cs2 = - let len = imin (length cs1) (length cs2) in - let cs = clone ~len cs2 in - ( xor_into cs1 cs len ; cs ) - - let split3 cs l1 l2 = - let l12 = l1 + l2 in - (sub cs 0 l1, sub cs l1 l2, sub cs l12 (length cs - l12)) - - let rpad cs size x = - let l = length cs and cs' = Cstruct.create_unsafe size in - if size < l then invalid_arg "Uncommon.Cs.rpad: size < len"; - blit cs 0 cs' 0 l ; - memset (sub cs' l (size - l)) x ; - cs' - - let lpad cs size x = - let l = length cs and cs' = Cstruct.create_unsafe size in - if size < l then invalid_arg "Uncommon.Cs.lpad: size < len"; - blit cs 0 cs' (size - l) l ; - memset (sub cs' 0 (size - l)) x ; - cs' - - let of_bytes xs = - let cs = Cstruct.create_unsafe @@ List.length xs in - List.iteri (fun i x -> set_uint8 cs i x) xs; - cs - - let b x = - let cs = Cstruct.create_unsafe 1 in ( set_uint8 cs 0 x ; cs ) - -end +(* revise once OCaml 4.13 is the lower bound *) +let string_get_uint8 buf idx = + Bytes.get_uint8 (Bytes.unsafe_of_string buf) idx diff --git a/tests/dune b/tests/dune index d0354cdf..05ae7adb 100644 --- a/tests/dune +++ b/tests/dune @@ -1,6 +1,6 @@ (library (name test_common) - (libraries mirage-crypto ounit2) + (libraries mirage-crypto ounit2 ohex) (modules test_common) (optional)) @@ -29,19 +29,19 @@ (modules test_entropy_collection) (package mirage-crypto-rng-mirage) (libraries mirage-crypto-rng-mirage mirage-unix mirage-time-unix - mirage-clock-unix duration)) + mirage-clock-unix duration ohex)) (test (name test_entropy_collection_async) (modules test_entropy_collection_async) (package mirage-crypto-rng-async) - (libraries mirage-crypto-rng-async)) + (libraries mirage-crypto-rng-async ohex)) (test (name test_entropy) (modules test_entropy) (package mirage-crypto-rng) - (libraries mirage-crypto-rng)) + (libraries mirage-crypto-rng ohex)) (test (name test_ec) @@ -63,5 +63,5 @@ (tests (names test_eio_rng test_eio_entropy_collection) (modules test_eio_rng test_eio_entropy_collection) - (libraries mirage-crypto-rng-eio duration eio_main) + (libraries mirage-crypto-rng-eio duration eio_main ohex) (package mirage-crypto-rng-eio)) diff --git a/tests/test_base.ml b/tests/test_base.ml index fe38f2eb..bb92e93e 100644 --- a/tests/test_base.ml +++ b/tests/test_base.ml @@ -7,7 +7,7 @@ open Test_common (* Xor *) let xor_cases = - cases_of (f2_eq ~msg:"xor" Uncommon.Cs.xor) [ + cases_of (f2_eq ~msg:"xor" Uncommon.xor) [ "00 01 02 03 04 05 06 07 08 09 0a 0b 0c" , "0c 0b 0a 09 08 07 06 05 04 03 02 01 00" , "0c 0a 08 0a 0c 02 00 02 0c 0a 08 0a 0c" ; @@ -16,9 +16,7 @@ let xor_cases = "0f 0e 0d 0c 0b 0a 09 08 07 06 05 04 03 02 01 00" , "0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f 0f" ; - "00 01 02", "00", "00" ; - - "00", "00 01 02", "00" ; + "00", "00", "00" ; "", "", "" ; ] diff --git a/tests/test_cipher.ml b/tests/test_cipher.ml index d43eb139..3bbf0c1d 100644 --- a/tests/test_cipher.ml +++ b/tests/test_cipher.ml @@ -20,8 +20,8 @@ let aes_ecb_cases = and check (key, out) _ = let enc = AES.ECB.encrypt ~key nist_sp_800_38a in let dec = AES.ECB.decrypt ~key enc in - assert_cs_equal ~msg:"ciphertext" out enc ; - assert_cs_equal ~msg:"plaintext" nist_sp_800_38a dec in + assert_oct_equal ~msg:"ciphertext" out enc ; + assert_oct_equal ~msg:"plaintext" nist_sp_800_38a dec in cases_of check [ case ~key: "2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c" @@ -53,8 +53,8 @@ let aes_cbc_cases = and check (key, iv, out) _ = let enc = AES.CBC.encrypt ~key ~iv nist_sp_800_38a in let dec = AES.CBC.decrypt ~key ~iv enc in - assert_cs_equal ~msg:"ciphertext" out enc ; - assert_cs_equal ~msg:"plaintext" nist_sp_800_38a dec in + assert_oct_equal ~msg:"ciphertext" out enc ; + assert_oct_equal ~msg:"plaintext" nist_sp_800_38a dec in cases_of check [ case ~key: "2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c" @@ -85,14 +85,14 @@ let aes_ctr_cases = let case ~key ~ctr ~out ~ctr1 = test_case @@ fun _ -> let open Cipher_block.AES.CTR in let key = vx key |> of_secret - and ctr = vx ctr |> ctr_of_cstruct - and ctr1 = vx ctr1 |> ctr_of_cstruct + and ctr = vx ctr |> ctr_of_octets + and ctr1 = vx ctr1 |> ctr_of_octets and out = vx out in let enc = encrypt ~key ~ctr nist_sp_800_38a in let dec = decrypt ~key ~ctr enc in - assert_cs_equal ~msg:"cipher" out enc; - assert_cs_equal ~msg:"plain" nist_sp_800_38a dec; - let blocks = Cstruct.length nist_sp_800_38a / block_size in + assert_oct_equal ~msg:"cipher" out enc; + assert_oct_equal ~msg:"plain" nist_sp_800_38a dec; + let blocks = String.length nist_sp_800_38a / block_size in assert_equal ~msg:"counters" ctr1 (add_ctr ctr (Int64.of_int blocks)) in [ case ~key: "2b7e1516 28aed2a6 abf71588 09cf4f3c" @@ -145,8 +145,8 @@ let gcm_cases = | None -> assert_failure "GCM decryption broken" | Some data -> data in - assert_cs_equal ~msg:"ciphertext" (Cstruct.append c t) cipher ; - assert_cs_equal ~msg:"decrypted plaintext" p pdata + assert_oct_equal ~msg:"ciphertext" (c ^ t) cipher ; + assert_oct_equal ~msg:"decrypted plaintext" p pdata in cases_of check [ @@ -292,9 +292,9 @@ let ccm_cases = let check (key, p, adata, nonce, c) _ = let cip = authenticate_encrypt ~key ~nonce ~adata p in - assert_cs_equal ~msg:"encrypt" c cip ; + assert_oct_equal ~msg:"encrypt" c cip ; match authenticate_decrypt ~key ~nonce ~adata c with - | Some x -> assert_cs_equal ~msg:"decrypt" p x + | Some x -> assert_oct_equal ~msg:"decrypt" p x | None -> assert_failure "CCM decryption broken" in @@ -330,18 +330,18 @@ let ccm_regressions = (* see RFC 3610 Section 2.1, AD of length 0 should be same as no AD *) let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "0001020304050607" - and plaintext = Cstruct.of_string "hello" + and plaintext = "hello" in - assert_cs_equal ~msg:"CCM no vs empty ad" + assert_oct_equal ~msg:"CCM no vs empty ad" (authenticate_encrypt ~key ~nonce plaintext) - (authenticate_encrypt ~adata:Cstruct.empty ~key ~nonce plaintext) + (authenticate_encrypt ~adata:"" ~key ~nonce plaintext) and short_nonce_enc _ = (* as reported in https://github.com/mirleft/ocaml-nocrypto/issues/167 *) (* valid nonce sizes for CCM are 7..13 (L can be 2..8, nonce is 15 - L)*) (* see RFC3610 Section 2.1 *) let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") - and nonce = Cstruct.empty - and plaintext = Cstruct.of_string "hello" + and nonce = "" + and plaintext = "hello" in assert_raises ~msg:"CCM with short nonce raises" (Invalid_argument "Mirage_crypto: CCM: nonce length not between 7 and 13: 0") @@ -349,7 +349,7 @@ let ccm_regressions = and short_nonce_enc2 _ = let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "00" - and plaintext = Cstruct.of_string "hello" + and plaintext = "hello" in assert_raises ~msg:"CCM with short nonce raises" (Invalid_argument "Mirage_crypto: CCM: nonce length not between 7 and 13: 1") @@ -357,7 +357,7 @@ let ccm_regressions = and short_nonce_enc3 _ = let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "000102030405" - and plaintext = Cstruct.of_string "hello" + and plaintext = "hello" in assert_raises ~msg:"CCM with short nonce raises" (Invalid_argument "Mirage_crypto: CCM: nonce length not between 7 and 13: 6") @@ -365,7 +365,7 @@ let ccm_regressions = and long_nonce_enc _ = let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "000102030405060708090a0b0c0d" - and plaintext = Cstruct.of_string "hello" + and plaintext = "hello" in assert_raises ~msg:"CCM with short nonce raises" (Invalid_argument "Mirage_crypto: CCM: nonce length not between 7 and 13: 14") @@ -374,23 +374,23 @@ let ccm_regressions = (* as reported in https://github.com/mirleft/ocaml-nocrypto/issues/168 *) let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "0001020304050607" - and adata = Cstruct.of_string "hello" - and p = Cstruct.empty + and adata = "hello" + and p = "" in let cipher = authenticate_encrypt ~adata ~key ~nonce p in match authenticate_decrypt ~key ~nonce ~adata cipher with - | Some x -> assert_cs_equal ~msg:"CCM decrypt of empty message" p x + | Some x -> assert_oct_equal ~msg:"CCM decrypt of empty message" p x | None -> assert_failure "decryption broken" and long_adata _ = let key = of_secret (vx "000102030405060708090a0b0c0d0e0f") and nonce = vx "0001020304050607" - and plaintext = Cstruct.of_string "hello" + and plaintext = "hello" (* [adata] is greater than [1 lsl 16 - 1 lsl 8] *) - and adata = Cstruct.create 65280 + and adata = String.make 65280 '\x00' and expected = vx "6592169e946f98973bc06d080f7c9dbb493a536f8a" in let cipher = authenticate_encrypt ~adata ~key ~nonce plaintext in - assert_cs_equal ~msg:"CCM encrypt of >=65280 adata" expected cipher + assert_oct_equal ~msg:"CCM encrypt of >=65280 adata" expected cipher in [ test_case no_vs_empty_ad ; @@ -406,7 +406,7 @@ let gcm_regressions = let open Cipher_block.AES.GCM in let msg = vx "000102030405060708090a0b0c0d0e0f" in let key = of_secret msg - and nonce = Cstruct.empty + and nonce = "" in let nonce_zero_length_enc _ = (* reported in https://github.com/mirleft/ocaml-nocrypto/issues/169 *) @@ -417,35 +417,20 @@ let gcm_regressions = assert_raises ~msg:"GCM with nonce of 0" (Invalid_argument "Mirage_crypto: GCM: invalid nonce of length 0") (fun () -> authenticate_decrypt ~key ~nonce msg) - and unaligned _ = - let key = of_secret (vx "00000000000000000000000000000000") - and c = vx "0388dace60b6a392f328c2b971b2fe78" - and p = vx "00000000000000000000000000000000" - and nonce = vx "000000000000000000000000" - and t = vx "ab6e47d42cec13bdf53a67b21257bddf" - in - let cipher = Cstruct.shift (Cstruct.concat [ Cstruct.create 1 ; c ; t ]) 1 in - let auth_dec ~key ~nonce cipher = match authenticate_decrypt ~key ~nonce cipher with - | None -> assert_failure "GCM decryption failure" - | Some x -> x - in - assert_cs_equal ~msg:"GCM with unaligned msg" - (auth_dec ~key ~nonce cipher) p in [ test_case nonce_zero_length_enc ; test_case nonce_zero_length_dec ; - test_case unaligned ; ] let chacha20_cases = let case msg ?ctr ~key ~nonce ?(input = String.make 128 '\000') output = let key = Chacha20.of_secret (vx key) - and nonce = vx_str nonce - and output = vx_str output + and nonce = vx nonce + and output = vx output in - assert_str_equal ~msg (Chacha20.crypt ~key ~nonce ?ctr input) output + assert_oct_equal ~msg (Chacha20.crypt ~key ~nonce ?ctr input) output in let rfc8439_input = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it." in let rfc8439_test_2_4_2 _ = @@ -477,17 +462,13 @@ let chacha20_cases = 61 16 1a e1 0b 59 4f 09 e2 6a 7e 90 2e cb d0 60 06 91|} in - assert_cs_equal ~msg:"Chacha20/Poly1305 RFC 8439 2.8.2 encrypt" - (Chacha20.authenticate_encrypt ~key ~nonce ~adata (Cstruct.of_string rfc8439_input)) + assert_oct_equal ~msg:"Chacha20/Poly1305 RFC 8439 2.8.2 encrypt" + (Chacha20.authenticate_encrypt ~key ~nonce ~adata rfc8439_input) output; - assert_cs_equal ~msg:"Chacha20/Poly1305 RFC 8439 2.8.2 decrypt" + assert_oct_equal ~msg:"Chacha20/Poly1305 RFC 8439 2.8.2 decrypt" (match Chacha20.authenticate_decrypt ~key ~nonce ~adata output with | Some cs -> cs | None -> assert_failure "Chacha20/poly1305 decryption broken") - (Cstruct.of_string rfc8439_input); - let input = Cstruct.(shift (append (create 16) (Cstruct.of_string rfc8439_input)) 16) in - assert_cs_equal ~msg:"Chacha20/Poly1305 RFC 8439 2.8.2 encrypt 2" - (Chacha20.authenticate_encrypt ~key ~nonce ~adata input) - output; + rfc8439_input; in (* from https://tools.ietf.org/html/draft-strombergson-chacha-test-vectors-01 *) let case ~key ~nonce ~output0 ~output1 _ = @@ -693,92 +674,92 @@ let chacha20_cases = ] let poly1305_rfc8439_2_5_2 _ = - let key = vx_str "85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b" + let key = vx "85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b" and data = "Cryptographic Forum Research Group" - and output = vx_str "a8061dc1305136c6c22b8baf0c0127a9" + and output = vx "a8061dc1305136c6c22b8baf0c0127a9" in - assert_str_equal ~msg:"poly 1305 RFC8439 Section 2.5.2" + assert_oct_equal ~msg:"poly 1305 RFC8439 Section 2.5.2" (Poly1305.mac ~key data) output let empty_cases _ = let open Cipher_block in - let plain = Cstruct.empty - and cipher = Cstruct.empty + let plain = "" + and cipher = "" in (* 3DES ECB CBC CTR *) Array.iter (fun key_size -> - let key = DES.ECB.of_secret (Cstruct.create key_size) in - assert_cs_equal ~msg:"DES ECB encrypt" cipher (DES.ECB.encrypt ~key plain) ; - assert_cs_equal ~msg:"DES ECB decrypt" plain (DES.ECB.decrypt ~key cipher)) + let key = DES.ECB.of_secret (String.make key_size '\x00') in + assert_oct_equal ~msg:"DES ECB encrypt" cipher (DES.ECB.encrypt ~key plain) ; + assert_oct_equal ~msg:"DES ECB decrypt" plain (DES.ECB.decrypt ~key cipher)) DES.ECB.key_sizes ; Array.iter (fun key_size -> - let key = DES.CBC.of_secret (Cstruct.create key_size) - and iv = Cstruct.create DES.CBC.block_size + let key = DES.CBC.of_secret (String.make key_size '\x00') + and iv = String.make DES.CBC.block_size '\x00' in - assert_cs_equal ~msg:"DES CBC encrypt" cipher (DES.CBC.encrypt ~key ~iv plain) ; - assert_cs_equal ~msg:"DES CBC decrypt" plain (DES.CBC.decrypt ~key ~iv cipher)) + assert_oct_equal ~msg:"DES CBC encrypt" cipher (DES.CBC.encrypt ~key ~iv plain) ; + assert_oct_equal ~msg:"DES CBC decrypt" plain (DES.CBC.decrypt ~key ~iv cipher)) DES.CBC.key_sizes ; Array.iter (fun key_size -> - let key = DES.CTR.of_secret (Cstruct.create key_size) - and ctr = DES.CTR.ctr_of_cstruct (Cstruct.create DES.CTR.block_size) + let key = DES.CTR.of_secret (String.make key_size '\x00') + and ctr = DES.CTR.ctr_of_octets (String.make DES.CTR.block_size '\x00') in - assert_cs_equal ~msg:"DES CTR encrypt" cipher (DES.CTR.encrypt ~key ~ctr plain) ; - assert_cs_equal ~msg:"DES CTR decrypt" plain (DES.CTR.decrypt ~key ~ctr cipher)) + assert_oct_equal ~msg:"DES CTR encrypt" cipher (DES.CTR.encrypt ~key ~ctr plain) ; + assert_oct_equal ~msg:"DES CTR decrypt" plain (DES.CTR.decrypt ~key ~ctr cipher)) DES.CTR.key_sizes ; (* AES ECB CBC CTR GCM CCM16 *) Array.iter (fun key_size -> - let key = AES.ECB.of_secret (Cstruct.create key_size) in - assert_cs_equal ~msg:"AES ECB encrypt" cipher (AES.ECB.encrypt ~key plain) ; - assert_cs_equal ~msg:"AES ECB decrypt" plain (AES.ECB.decrypt ~key cipher)) + let key = AES.ECB.of_secret (String.make key_size '\x00') in + assert_oct_equal ~msg:"AES ECB encrypt" cipher (AES.ECB.encrypt ~key plain) ; + assert_oct_equal ~msg:"AES ECB decrypt" plain (AES.ECB.decrypt ~key cipher)) AES.ECB.key_sizes ; Array.iter (fun key_size -> - let key = AES.CBC.of_secret (Cstruct.create key_size) - and iv = Cstruct.create AES.CBC.block_size + let key = AES.CBC.of_secret (String.make key_size '\x00') + and iv = String.make AES.CBC.block_size '\x00' in - assert_cs_equal ~msg:"AES CBC encrypt" cipher (AES.CBC.encrypt ~key ~iv plain) ; - assert_cs_equal ~msg:"AES CBC decrypt" plain (AES.CBC.decrypt ~key ~iv cipher)) + assert_oct_equal ~msg:"AES CBC encrypt" cipher (AES.CBC.encrypt ~key ~iv plain) ; + assert_oct_equal ~msg:"AES CBC decrypt" plain (AES.CBC.decrypt ~key ~iv cipher)) AES.CBC.key_sizes ; Array.iter (fun key_size -> - let key = AES.CTR.of_secret (Cstruct.create key_size) - and ctr = AES.CTR.ctr_of_cstruct (Cstruct.create AES.CTR.block_size) + let key = AES.CTR.of_secret (String.make key_size '\x00') + and ctr = AES.CTR.ctr_of_octets (String.make AES.CTR.block_size '\x00') in - assert_cs_equal ~msg:"AES CTR encrypt" cipher (AES.CTR.encrypt ~key ~ctr plain) ; - assert_cs_equal ~msg:"AES CTR decrypt" plain (AES.CTR.decrypt ~key ~ctr cipher)) + assert_oct_equal ~msg:"AES CTR encrypt" cipher (AES.CTR.encrypt ~key ~ctr plain) ; + assert_oct_equal ~msg:"AES CTR decrypt" plain (AES.CTR.decrypt ~key ~ctr cipher)) AES.CTR.key_sizes ; Array.iter (fun key_size -> - let key = AES.CCM16.of_secret (Cstruct.create key_size) in + let key = AES.CCM16.of_secret (String.make key_size '\x00') in let test_one nonce = let c, tag = AES.CCM16.authenticate_encrypt_tag ~key ~nonce plain in - assert_cs_equal ~msg:"AES CCM16 encrypt" cipher c ; + assert_oct_equal ~msg:"AES CCM16 encrypt" cipher c ; match AES.CCM16.authenticate_decrypt_tag ~key ~nonce ~tag cipher with | None -> assert false - | Some p -> assert_cs_equal ~msg:"AES CCM16 decrypt" plain p + | Some p -> assert_oct_equal ~msg:"AES CCM16 decrypt" plain p in - test_one (Cstruct.create 7); - test_one (Cstruct.create 8); - test_one (Cstruct.create 13)) + test_one (String.make 7 '\x00'); + test_one (String.make 8 '\x00'); + test_one (String.make 13 '\x00')) AES.CCM16.key_sizes ; (* ChaCha20 *) Array.iter (fun key_size -> - let key = Chacha20.of_secret (Cstruct.create key_size) in + let key = Chacha20.of_secret (String.make key_size '\x00') in let test_one nonce = let c, tag = Chacha20.authenticate_encrypt_tag ~key ~nonce plain in - assert_cs_equal ~msg:"Chacha20 encrypt" cipher c ; + assert_oct_equal ~msg:"Chacha20 encrypt" cipher c ; match Chacha20.authenticate_decrypt_tag ~key ~nonce ~tag cipher with | None -> assert false - | Some p -> assert_cs_equal ~msg:"Chacha20 decrypt" plain p + | Some p -> assert_oct_equal ~msg:"Chacha20 decrypt" plain p in - test_one (Cstruct.create 8); + test_one (String.make 8 '\x00'); if key_size = 32 then - test_one (Cstruct.create 12)) + test_one (String.make 12 '\x00')) [| 16 ; 32 |] ; (* ARC4 *) - let key = Cipher_stream.ARC4.of_secret (Cstruct.create 16) in - assert_cs_equal ~msg:"ARC4 encrypt" cipher (Cipher_stream.ARC4.(encrypt ~key plain).message) ; - assert_cs_equal ~msg:"ARC4 decrypt" plain (Cipher_stream.ARC4.(decrypt ~key cipher).message) + let key = Cipher_stream.ARC4.of_secret (String.make 16 '\x00') in + assert_oct_equal ~msg:"ARC4 encrypt" cipher (Cipher_stream.ARC4.(encrypt ~key plain).message) ; + assert_oct_equal ~msg:"ARC4 decrypt" plain (Cipher_stream.ARC4.(decrypt ~key cipher).message) let suite = [ "AES-ECB" >::: [ "SP 300-38A" >::: aes_ecb_cases ] ; diff --git a/tests/test_common.ml b/tests/test_common.ml index a64b0b96..10c19c51 100644 --- a/tests/test_common.ml +++ b/tests/test_common.ml @@ -34,11 +34,6 @@ let of_hex ?(skip_ws = true) s = assert (leftover = None); String.init (List.length chars) (fun i -> char_of_int (List.nth chars i)) -let rec blocks_of_cs n cs = - let open Cstruct in - if length cs <= n then [ cs ] - else sub cs 0 n :: blocks_of_cs n (shift cs n) - let rec range a b = if a > b then [] else a :: range (succ a) b @@ -53,14 +48,9 @@ let eq_opt eq a b = match (a, b) with | (Some x, Some y) -> eq x y | _ -> false -let assert_cs_equal ?msg = - assert_equal ~cmp:Cstruct.equal ?msg - ~pp_diff:(pp_diff Cstruct.hexdump_pp) - -let pp_octets pp ppf (a, b) = - pp Cstruct.hexdump_pp ppf (Cstruct.of_string a, Cstruct.of_string b) +let pp_octets pp = pp (Ohex.pp_hexdump ()) -let assert_str_equal ?msg = +let assert_oct_equal ?msg = assert_equal ~cmp:String.equal ?msg ~pp_diff:(pp_octets pp_diff) let iter_list xs f = List.iter f xs @@ -70,23 +60,9 @@ let cases_of f = let any _ = true -let vx = Cstruct.of_hex - -let vx_str data = Cstruct.to_string (Cstruct.of_hex data) +let vx = Ohex.decode let f1_eq ?msg f (a, b) _ = - assert_cs_equal ?msg (f (vx a)) (vx b) - -let f1_opt_eq ?msg f (a, b) _ = - let maybe = function None -> None | Some h -> Some (vx h) in - let (a, b) = vx a, maybe b in - let eq_opt eq a b = match (a, b) with - | (Some x, Some y) -> eq x y - | (None , None ) -> true - | _ -> false - in - assert_equal b (f a) ?msg - ~cmp:(eq_opt Cstruct.equal) - ~pp_diff:(pp_diff (pp_opt Cstruct.hexdump_pp)) + assert_oct_equal ?msg (f (vx a)) (vx b) let f2_eq ?msg f (a, b, c) = f1_eq ?msg (f (vx a)) (b, c) diff --git a/tests/test_dh.ml b/tests/test_dh.ml index f43dfc07..20465d16 100644 --- a/tests/test_dh.ml +++ b/tests/test_dh.ml @@ -15,12 +15,12 @@ let dh_selftest ~bits n = ~cmp:(eq_opt String.equal) ~pp_diff:(pp_diff (fun ppf -> function | None -> Format.fprintf ppf "None" - | Some a -> Format.fprintf ppf "Some(%a)" Cstruct.hexdump_pp (Cstruct.of_string a))) + | Some a -> Format.fprintf ppf "Some(%a)" (Ohex.pp_hexdump ()) a)) ~msg:"shared secret" let dh_shared_0 = "shared_0" >:: fun _ -> - let gy = vx_str + let gy = vx "14 ac e2 c0 9c c0 0c 25 89 71 b2 d0 1c 94 58 21 02 23 b7 23 ec 3e 24 e5 a3 c2 fd 16 cc 49 f0 e2 87 62 a5 a0 73 f5 de 5b 9b eb c3 60 0b a4 03 38 @@ -33,7 +33,7 @@ let dh_shared_0 = a5 23 69 38 7e ec b5 fc 4b 89 42 c4 32 fa e5 58 6f 39 5d a7 4e cd b5 da dc 1e 52 fe a4 33 72 c1 82 48 8a 5b c1 44 bc 60 9b 38 5b 80 5f 44 14 93" - and s = vx_str + and s = vx "f9 47 87 95 d2 a1 6d d1 7c c8 a9 c0 71 28 a2 82 71 95 7e 79 87 0b fc 34 a2 42 ec 42 ac cc 42 81 7b f6 c4 f5 80 a9 70 e3 35 93 9b a3 21 81 a4 e3 @@ -46,7 +46,7 @@ let dh_shared_0 = 29 22 63 6e bb 1a 7f 93 bd 98 db 20 94 f8 f0 2e db ce 9d 79 db b9 a7 41 5f e5 29 a2 31 f8 e2 c3 30 6a 09 f2 16 a7 30 8c 2f 36 7b 71 99 1e 28 54" - and shared = vx_str + and shared = vx "a7 40 0d eb f0 4b 2b ec cb 90 3c 55 2d 3c 17 63 b2 4b 4e 1a ff 1e a0 24 c6 56 e3 5e 44 7b d0 01 ef b3 6b 57 20 0e 15 95 b1 53 1a 83 16 3a b1 61 @@ -64,7 +64,7 @@ let dh_shared_0 = match Dh.(shared (fst (key_of_secret grp ~s)) gy) with | None -> assert_failure "degenerate shared secret" | Some shared' -> - assert_str_equal ~msg:"shared secret" shared shared' + assert_oct_equal ~msg:"shared secret" shared shared' let suite = [ dh_selftest ~bits:16 1000 ; diff --git a/tests/test_dsa.ml b/tests/test_dsa.ml index cd1c3e43..eca2e557 100644 --- a/tests/test_dsa.ml +++ b/tests/test_dsa.ml @@ -15,19 +15,19 @@ open Test_common let dsa_test ~priv ~msg ?k ~r ~s ~hash _ = let hmsg = Digestif.(digest_string hash msg |> to_raw_string hash) in let (r', s') = Dsa.sign ~mask:`No ~key:priv ?k hmsg in - assert_str_equal ~msg:"computed r" r r' ; - assert_str_equal ~msg:"computed s" s s' ; + assert_oct_equal ~msg:"computed r" r r' ; + assert_oct_equal ~msg:"computed s" s s' ; (* now with masking *) let (r', s') = Dsa.sign ~key:priv ?k hmsg in - assert_str_equal ~msg:"computed r (masked)" r r' ; - assert_str_equal ~msg:"computed s (masked)" s s' ; + assert_oct_equal ~msg:"computed r (masked)" r r' ; + assert_oct_equal ~msg:"computed s (masked)" s s' ; let pub = Dsa.pub_of_priv priv in assert_bool "verify of given r, s" (Dsa.verify ~key:pub (r, s) hmsg) ; assert_bool "verify of computed r, s" (Dsa.verify ~key:pub (r', s') hmsg) -let params ~p ~q ~g = vx_str p, vx_str q, vx_str g +let params ~p ~q ~g = vx p, vx q, vx g let priv_of f ~p ~q ~gg ~x ~y = match Dsa.priv ~fips:true ~p:(f p) ~q:(f q) ~gg:(f gg) ~x:(f x) ~y:(f y) () with @@ -35,14 +35,14 @@ let priv_of f ~p ~q ~gg ~x ~y = | Error (`Msg m) -> invalid_arg "bad DSA private key %s" m let priv_of_cs = priv_of Z_extra.of_octets_be -let priv_of_hex = priv_of (fun cs -> vx_str cs |> Z_extra.of_octets_be) +let priv_of_hex = priv_of (fun cs -> vx cs |> Z_extra.of_octets_be) let case_of ~domain ~hash ~x ~y ~k ~r ~s ~msg = let (p, q, gg) = domain in - let priv = priv_of_cs ~p ~q ~gg ~x:(vx_str x) ~y:(vx_str y) - and (r, s) = vx_str r, vx_str s - and k = Z_extra.of_octets_be (vx_str k) - and msg = vx_str msg in + let priv = priv_of_cs ~p ~q ~gg ~x:(vx x) ~y:(vx y) + and (r, s) = vx r, vx s + and k = Z_extra.of_octets_be (vx k) + and msg = vx msg in dsa_test ~priv ~msg ~k ~r ~s ~hash let sha1_cases = @@ -2191,7 +2191,7 @@ let test_rfc6979 (type a) ~priv ~msg ~(hash: a Digestif.hash) ~k ~r ~s _ = let module H = (val (Digestif.module_of hash)) in let module K = Dsa.K_gen (H) in K.generate ~key:priv h1 in - assert_str_equal + assert_oct_equal ~msg:"computed k" k (Z_extra.to_octets_be ~size:(Z.numbits priv.Dsa.q // 8) k') ; dsa_test ~priv ~msg ~k:k' ~r ~s ~hash () @@ -2216,7 +2216,7 @@ let rfc6979_dsa_1024 = in let case ~msg ~hash ~k ~r ~s = - test_rfc6979 ~priv ~msg ~k:(vx_str k) ~r:(vx_str r) ~s:(vx_str s) ~hash + test_rfc6979 ~priv ~msg ~k:(vx k) ~r:(vx r) ~s:(vx s) ~hash in [ case ~msg:"sample" ~hash:Digestif.sha1 ~k:"7BDB6B0FF756E1BB5D53583EF979082F9AD5BD5B" @@ -2300,7 +2300,7 @@ let rfc6979_dsa_2048 = in let case ~msg ~hash ~k ~r ~s = - test_rfc6979 ~priv ~msg ~k:(vx_str k) ~r:(vx_str r) ~s:(vx_str s) ~hash + test_rfc6979 ~priv ~msg ~k:(vx k) ~r:(vx r) ~s:(vx s) ~hash in [ case ~hash:Digestif.sha1 ~msg:"sample" ~k:"888FA6F7738A41BDC9846466ABDB8174C0338250AE50CE955CA16230F9CBD53E" diff --git a/tests/test_eio_entropy_collection.ml b/tests/test_eio_entropy_collection.ml index 15e69af3..b2764517 100644 --- a/tests/test_eio_entropy_collection.ml +++ b/tests/test_eio_entropy_collection.ml @@ -8,13 +8,12 @@ module Printing_rng = struct let pools = 1 let reseed ~g:_ data = - Format.printf "reseeding: %a@.%!" Cstruct.hexdump_pp (Cstruct.of_string data) + Format.printf "reseeding:@.%a@.%!" (Ohex.pp_hexdump ()) data let accumulate ~g:_ source = let print data = Format.printf "accumulate: (src: %a) %a@.%!" - Mirage_crypto_rng.Entropy.pp_source source Cstruct.hexdump_pp - (Cstruct.of_string data) + Mirage_crypto_rng.Entropy.pp_source source Ohex.pp data in `Acc print end diff --git a/tests/test_entropy.ml b/tests/test_entropy.ml index 29f1e5e3..f0984d38 100644 --- a/tests/test_entropy.ml +++ b/tests/test_entropy.ml @@ -13,7 +13,7 @@ let cpu_bootstrap_check () = try let data' = cpu_rng_bootstrap 1 in if String.equal !data data' then begin - Cstruct.hexdump (Cstruct.of_string data'); + Ohex.pp Format.std_formatter data'; failwith ("same data from CPU bootstrap at " ^ string_of_int i); end; data := data' @@ -24,7 +24,7 @@ let whirlwind_bootstrap_check () = for i = 0 to 10 do let data' = Mirage_crypto_rng.Entropy.whirlwind_bootstrap 1 in if String.equal !data data' then begin - Cstruct.hexdump (Cstruct.of_string data'); + Ohex.pp Format.std_formatter data'; failwith ("same data from whirlwind bootstrap at " ^ string_of_int i); end; data := data' @@ -34,7 +34,7 @@ let timer_check () = for i = 0 to 10 do let data' = Mirage_crypto_rng.Entropy.interrupt_hook () () in if String.equal !data data' then begin - Cstruct.hexdump (Cstruct.of_string data'); + Ohex.pp Format.std_formatter data'; failwith ("same data from timer at " ^ string_of_int i); end; data := data' diff --git a/tests/test_entropy_collection.ml b/tests/test_entropy_collection.ml index 4905af8a..bed653b2 100644 --- a/tests/test_entropy_collection.ml +++ b/tests/test_entropy_collection.ml @@ -10,13 +10,12 @@ module Printing_rng = struct let generate_into ~g:_ _buf ~off:_ _len = assert false let reseed ~g:_ data = - Format.printf "reseeding: %a@.%!" Cstruct.hexdump_pp (Cstruct.of_string data) + Format.printf "reseeding:@.%a@.%!" (Ohex.pp_hexdump ()) data let accumulate ~g:_ source = let print data = Format.printf "accumulate: (src: %a) %a@.%!" - Mirage_crypto_rng.Entropy.pp_source source Cstruct.hexdump_pp - (Cstruct.of_string data) + Mirage_crypto_rng.Entropy.pp_source source Ohex.pp data in `Acc print diff --git a/tests/test_entropy_collection_async.ml b/tests/test_entropy_collection_async.ml index f2beeccb..d2331249 100644 --- a/tests/test_entropy_collection_async.ml +++ b/tests/test_entropy_collection_async.ml @@ -11,13 +11,12 @@ module Printing_rng = struct let generate_into ~g:_ _buf ~off:_ _len = assert false let reseed ~g:_ data = - Format.printf "reseeding: %a@.%!" Cstruct.hexdump_pp (Cstruct.of_string data) + Format.printf "reseeding:@.%a@.%!" (Ohex.pp_hexdump ()) data let accumulate ~g:_ source = let print data = Format.printf "accumulate: (src: %a) %a@.%!" - Mirage_crypto_rng.Entropy.pp_source source Cstruct.hexdump_pp - (Cstruct.of_string data) + Mirage_crypto_rng.Entropy.pp_source source Ohex.pp data in `Acc print diff --git a/tests/test_numeric.ml b/tests/test_numeric.ml index 5dc2878d..d5eca6ed 100644 --- a/tests/test_numeric.ml +++ b/tests/test_numeric.ml @@ -17,7 +17,7 @@ let n_decode_reencode_selftest ~typ ~bytes n = typ ^ " selftest" >:: times ~n @@ fun _ -> let cs = Mirage_crypto_rng.generate bytes in let cs' = Z_extra.(to_octets_be ~size:bytes @@ of_octets_be cs) in - assert_str_equal cs cs' + assert_oct_equal cs cs' let random_n_selftest ~typ n bounds = typ ^ " selftest" >::: ( diff --git a/tests/test_random_runner.ml b/tests/test_random_runner.ml index 4bb7fe37..092100c2 100644 --- a/tests/test_random_runner.ml +++ b/tests/test_random_runner.ml @@ -15,57 +15,61 @@ let sample arr = let ecb_selftest (m : (module Cipher_block.S.ECB)) n = let module C = ( val m ) in "selftest" >:: times ~n @@ fun _ -> - let data = Cstruct.of_string (Mirage_crypto_rng.generate (C.block_size * 8)) - and key = C.of_secret @@ Cstruct.of_string (Mirage_crypto_rng.generate (sample C.key_sizes)) in + let data = Mirage_crypto_rng.generate (C.block_size * 8) + and key = C.of_secret @@ Mirage_crypto_rng.generate (sample C.key_sizes) in let data' = C.( data |> encrypt ~key |> encrypt ~key |> decrypt ~key |> decrypt ~key ) in - assert_cs_equal ~msg:"ecb mismatch" data data' + assert_oct_equal ~msg:"ecb mismatch" data data' let cbc_selftest (m : (module Cipher_block.S.CBC)) n = let module C = ( val m ) in "selftest" >:: times ~n @@ fun _ -> - let data = Cstruct.of_string (Mirage_crypto_rng.generate (C.block_size * 8)) - and iv = Cstruct.of_string (Mirage_crypto_rng.generate C.block_size) - and key = C.of_secret @@ Cstruct.of_string (Mirage_crypto_rng.generate (sample C.key_sizes)) in - assert_cs_equal ~msg:"CBC e->e->d->d" data + let data = Mirage_crypto_rng.generate (C.block_size * 8) + and iv = Mirage_crypto_rng.generate C.block_size + and key = C.of_secret @@ Mirage_crypto_rng.generate (sample C.key_sizes) in + assert_oct_equal ~msg:"CBC e->e->d->d" data C.( data |> encrypt ~key ~iv |> encrypt ~key ~iv |> decrypt ~key ~iv |> decrypt ~key ~iv ); - let (d1, d2) = Cstruct.split data (C.block_size * 4) in - assert_cs_equal ~msg:"CBC chain" + let (d1, d2) = + String.sub data 0 (C.block_size * 4), + String.sub data (C.block_size * 4) (String.length data - C.block_size * 4) + in + assert_oct_equal ~msg:"CBC chain" C.(encrypt ~key ~iv data) C.( let e1 = encrypt ~key ~iv d1 in - Cstruct.append e1 (encrypt ~key ~iv:(C.next_iv ~iv e1) d2) ) + e1 ^ encrypt ~key ~iv:(next_iv ~iv e1) d2) let ctr_selftest (m : (module Cipher_block.S.CTR)) n = let module M = (val m) in let bs = M.block_size in "selftest" >:: times ~n @@ fun _ -> - let key = M.of_secret @@ Cstruct.of_string (Mirage_crypto_rng.generate (sample M.key_sizes)) - and ctr = Mirage_crypto_rng.generate bs |> Cstruct.of_string |> M.ctr_of_cstruct - and data = Cstruct.of_string Mirage_crypto_rng.(generate @@ bs + Randomconv.int ~bound:(20 * bs) Mirage_crypto_rng.generate) in + let key = M.of_secret @@ Mirage_crypto_rng.generate (sample M.key_sizes) + and ctr = Mirage_crypto_rng.generate bs |> M.ctr_of_octets + and data = Mirage_crypto_rng.(generate @@ bs + Randomconv.int ~bound:(20 * bs) Mirage_crypto_rng.generate) in let enc = M.encrypt ~key ~ctr data in let dec = M.decrypt ~key ~ctr enc in - assert_cs_equal ~msg:"CTR e->d" data dec; + assert_oct_equal ~msg:"CTR e->d" data dec; let (d1, d2) = - Cstruct.split data @@ bs * Randomconv.int ~bound:(Cstruct.length data / bs) Mirage_crypto_rng.generate in - assert_cs_equal ~msg:"CTR chain" enc @@ - Cstruct.append (M.encrypt ~key ~ctr d1) - (M.encrypt ~key ~ctr:(M.next_ctr ~ctr d1) d2) + let s = bs * Randomconv.int ~bound:(String.length data / bs) Mirage_crypto_rng.generate in + String.sub data 0 s, String.sub data s (String.length data - s) + in + assert_oct_equal ~msg:"CTR chain" enc @@ + M.encrypt ~key ~ctr d1 ^ M.encrypt ~key ~ctr:(M.next_ctr ~ctr d1) d2 let ctr_offsets (type c) ~zero (m : (module Cipher_block.S.CTR with type ctr = c)) n = let module M = (val m) in "offsets" >:: fun _ -> - let key = M.of_secret @@ Cstruct.of_string (Mirage_crypto_rng.generate M.key_sizes.(0)) in + let key = M.of_secret @@ Mirage_crypto_rng.generate M.key_sizes.(0) in for i = 0 to n - 1 do let ctr = match i with | 0 -> M.add_ctr zero (-1L) - | _ -> Mirage_crypto_rng.generate M.block_size |> Cstruct.of_string |> M.ctr_of_cstruct + | _ -> Mirage_crypto_rng.generate M.block_size |> M.ctr_of_octets and gap = Randomconv.int ~bound:64 Mirage_crypto_rng.generate in let s1 = M.stream ~key ~ctr ((gap + 1) * M.block_size) and s2 = M.stream ~key ~ctr:(M.add_ctr ctr (Int64.of_int gap)) M.block_size in - assert_cs_equal ~msg:"shifted stream" - Cstruct.(sub s1 (gap * M.block_size) M.block_size) s2 + assert_oct_equal ~msg:"shifted stream" + String.(sub s1 (gap * M.block_size) M.block_size) s2 done let xor_selftest n = @@ -79,9 +83,9 @@ let xor_selftest n = let x1 = Uncommon.(xor xyz (xor y z)) and x2 = Uncommon.(xor (xor z y) xyz) in - assert_str_equal ~msg:"assoc" xyz xyz' ; - assert_str_equal ~msg:"invert" x x1 ; - assert_str_equal ~msg:"commut" x1 x2 + assert_oct_equal ~msg:"assoc" xyz xyz' ; + assert_oct_equal ~msg:"invert" x x1 ; + assert_oct_equal ~msg:"commut" x1 x2 let suite = "All" >::: [ diff --git a/tests/test_rsa.ml b/tests/test_rsa.ml index 519e0a0b..dd137b80 100644 --- a/tests/test_rsa.ml +++ b/tests/test_rsa.ml @@ -93,7 +93,7 @@ let rsa_selftest ~bits n = let enc = Rsa.(encrypt ~key:(pub_of_priv key) msg) in let dec = Rsa.(decrypt ~key enc) in - assert_str_equal + assert_oct_equal ~msg:Printf.(sprintf "failed decryption with") msg dec @@ -112,7 +112,7 @@ let rsa_pkcs1_encode_selftest ~bits n = let sgn = Rsa.PKCS1.sig_encode ~key msg in match Rsa.(PKCS1.sig_decode ~key:(pub_of_priv key) sgn) with | None -> assert_failure ("unpad failure " ^ show_key_size key) - | Some dec -> assert_str_equal msg dec + | Some dec -> assert_oct_equal msg dec ~msg:("recovery failure " ^ show_key_size key) let rsa_pkcs1_sign_selftest n = @@ -135,7 +135,7 @@ let rsa_pkcs1_encrypt_selftest ~bits n = let enc = Rsa.(PKCS1.encrypt ~key:(pub_of_priv key) msg) in match Rsa.PKCS1.decrypt ~key enc with | None -> assert_failure ("unpad failure " ^ show_key_size key) - | Some dec -> assert_str_equal msg dec + | Some dec -> assert_oct_equal msg dec ~msg:("recovery failure " ^ show_key_size key) let rsa_oaep_encrypt_selftest ~bits n = @@ -150,27 +150,27 @@ let rsa_oaep_encrypt_selftest ~bits n = let enc = OAEP_MD5.encrypt ~key:(Rsa.pub_of_priv key) msg in (match OAEP_MD5.decrypt ~key enc with | None -> assert_failure "unpad failure" - | Some dec -> assert_str_equal msg dec ~msg:"recovery failure"); + | Some dec -> assert_oct_equal msg dec ~msg:"recovery failure"); let msg = Mirage_crypto_rng.generate (bits // 8 - 2 * Digestif.SHA1.digest_size - 2) in let enc = OAEP_SHA1.encrypt ~key:(Rsa.pub_of_priv key) msg in (match OAEP_SHA1.decrypt ~key enc with | None -> assert_failure "unpad failure" - | Some dec -> assert_str_equal msg dec ~msg:"recovery failure"); + | Some dec -> assert_oct_equal msg dec ~msg:"recovery failure"); let msg = Mirage_crypto_rng.generate (bits // 8 - 2 * Digestif.SHA224.digest_size - 2) in let enc = OAEP_SHA224.encrypt ~key:(Rsa.pub_of_priv key) msg in (match OAEP_SHA224.decrypt ~key enc with | None -> assert_failure "unpad failure" - | Some dec -> assert_str_equal msg dec ~msg:"recovery failure"); + | Some dec -> assert_oct_equal msg dec ~msg:"recovery failure"); let msg = Mirage_crypto_rng.generate (bits // 8 - 2 * Digestif.SHA256.digest_size - 2) in let enc = OAEP_SHA256.encrypt ~key:(Rsa.pub_of_priv key) msg in (match OAEP_SHA256.decrypt ~key enc with | None -> assert_failure "unpad failure" - | Some dec -> assert_str_equal msg dec ~msg:"recovery failure"); + | Some dec -> assert_oct_equal msg dec ~msg:"recovery failure"); let msg = Mirage_crypto_rng.generate (bits // 8 - 2 * Digestif.SHA384.digest_size - 2) in let enc = OAEP_SHA384.encrypt ~key:(Rsa.pub_of_priv key) msg in (match OAEP_SHA384.decrypt ~key enc with | None -> assert_failure "unpad failure" - | Some dec -> assert_str_equal msg dec ~msg:"recovery failure") + | Some dec -> assert_oct_equal msg dec ~msg:"recovery failure") let rsa_pss_sign_selftest ~bits n = let module Pss_sha1 = Rsa.PSS (Digestif.SHA1) in @@ -197,10 +197,10 @@ let rsa_pkcs1_cases = in let case ~hash ~msg ~sgn = test_case @@ fun _ -> - let msg = vx_str msg and sgn = vx_str sgn in + let msg = vx msg and sgn = vx sgn in let key, public = key () in Rsa.(PKCS1.sign ~hash ~key (`Message msg)) - |> assert_str_equal ~msg:"recomputing sig:" sgn ; + |> assert_oct_equal ~msg:"recomputing sig:" sgn ; Rsa.(PKCS1.verify ~hashp:any ~key:public ~signature:sgn (`Message msg)) |> assert_bool "sig verification" in @@ -250,11 +250,11 @@ let rsa_pss_cases = let case (type a) ~(hash : a Digestif.hash) ~msg ~sgn = test_case @@ fun _ -> let module H = (val Digestif.module_of hash) in let module Pss = Rsa.PSS (H) in - let msg = vx_str msg and sgn = vx_str sgn and salt = vx_str salt in + let msg = vx msg and sgn = vx sgn and salt = vx salt in let key, public = key () in let slen = String.length salt in Pss.sign ~g:(random_is salt) ~slen ~mask:`No ~key (`Message msg) - |> assert_str_equal ~msg:"recomputing sig:" sgn ; + |> assert_oct_equal ~msg:"recomputing sig:" sgn ; Pss.verify ~key:public ~slen ~signature:sgn (`Message msg) |> assert_bool "sig verification" in diff --git a/tests/wycheproof/dune b/tests/wycheproof/dune index 4aaf545a..f41d446c 100644 --- a/tests/wycheproof/dune +++ b/tests/wycheproof/dune @@ -1,6 +1,6 @@ (library (name wycheproof) - (libraries yojson ppx_deriving_yojson.runtime hex) + (libraries yojson ppx_deriving_yojson.runtime) (preprocess (pps ppx_deriving.std ppx_deriving_yojson)) (optional)) diff --git a/tests/wycheproof/wycheproof.ml b/tests/wycheproof/wycheproof.ml index 91d2c361..0ca562f8 100644 --- a/tests/wycheproof/wycheproof.ml +++ b/tests/wycheproof/wycheproof.ml @@ -4,14 +4,44 @@ let pp_json = Yojson.Safe.pretty_print type hex = string [@@deriving eq] -let pp_hex fmt s = - let (`Hex h) = Hex.of_string s in - Format.pp_print_string fmt h +let pp_hex fmt buf = + let n = String.length buf in + let bbuf = Bytes.unsafe_of_string buf in + for i = n - 1 downto 0 do + let byte = Bytes.get_uint8 bbuf i in + Format.fprintf fmt "%02x" byte + done + +let hex_of_string s = + let fold f acc str = + let st = ref acc in + String.iter (fun c -> st := f !st c) str; + !st + and digit c = + match c with + | '0'..'9' -> int_of_char c - 0x30 + | 'A'..'F' -> int_of_char c - 0x41 + 10 + | 'a'..'f' -> int_of_char c - 0x61 + 10 + | _ -> invalid_arg "bad character" + in + let out = Bytes.create (String.length s / 2) in + let _idx, leftover = + fold (fun (idx, leftover) c -> + let c = digit c in + match leftover with + | None -> idx, Some (c lsl 4) + | Some c' -> + Bytes.set_uint8 out idx (c' lor c); + succ idx, None) + (0, None) s + in + assert (leftover = None); + Bytes.unsafe_to_string out let hex_of_yojson json = let padded s = if String.length s mod 2 = 0 then s else "0" ^ s in match [%of_yojson: string] json with - | Ok s -> Ok (Hex.to_string (`Hex (padded s))) + | Ok s -> Ok (hex_of_string (padded s)) | Error _ as e -> e type test_result = Valid | Acceptable | Invalid [@@deriving show]