Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rANS (rans4x8) encoding support for CRAM block compression #330

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 327 additions & 8 deletions src/cljam/io/cram/codecs/rans4x8.clj
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
(ns cljam.io.cram.codecs.rans4x8
(:require [cljam.io.util.byte-buffer :as bb]
[cljam.io.cram.itf8 :as itf8])
(:import [java.util Arrays]))
(:require [cljam.io.cram.itf8 :as itf8]
[cljam.io.util.byte-buffer :as bb])
(:import [java.nio Buffer ByteBuffer]
[java.util Arrays]))

;; Runtime class of a primitive byte array, obtained from an empty instance.
(def ^:private byte-array-type (type (byte-array 0)))
;; Runtime class of a primitive int array, obtained from an empty instance.
(def ^:private int-array-type (type (int-array 0)))
Expand Down Expand Up @@ -49,9 +50,9 @@
(let [curr (aget cum-freqs i)]
(if (= start curr)
(recur (inc i) start)
(do (Arrays/fill arr start curr (byte (dec i)))
(do (Arrays/fill arr start curr (unchecked-byte (dec i)))
(recur (inc i) curr))))
(Arrays/fill arr start 4096 (byte 255))))
(Arrays/fill arr start 4096 (unchecked-byte 255))))
arr))

(defn- advance-step ^long [^long c ^long f ^long state]
Expand Down Expand Up @@ -79,7 +80,7 @@
state' (->> state
(advance-step (aget cum-freqs sym) (aget freqs sym))
(renormalize-state bb))]
(aset out i (byte sym))
(aset out i (unchecked-byte sym))
(aset states j state')))
out))

Expand Down Expand Up @@ -107,7 +108,7 @@
(advance-step (aget cfreqs sym)
(aget ^ints (aget freqs last-sym) sym))
(renormalize-state bb))]
(aset out (+ i (* j quarter)) (byte sym))
(aset out (+ i (* j quarter)) (unchecked-byte sym))
(aset states j state')
(aset last-syms j sym))))
(dotimes [i (- n-out truncated)]
Expand All @@ -120,7 +121,7 @@
(advance-step (aget cfreq sym)
(aget ^ints (aget freqs last-sym) sym))
(renormalize-state bb))]
(aset out (+ i truncated) (byte sym))
(aset out (+ i truncated) (unchecked-byte sym))
(aset states 3 state')
(aset last-syms 3 sym)))
out))
Expand All @@ -135,3 +136,321 @@
(if (zero? order)
(decode0 bb n-out)
(decode1 bb n-out))))

(defn- rescale-frequencies!
  "Multiplies every non-zero entry of `freqs` by `tr` / 2^31 in place, clamping
  each result to at least 1 so that a symbol that occurs never loses its slot.
  Returns the sum of the rescaled entries."
  ^long [^ints freqs ^long tr]
  (loop [i 0, fsum 0]
    (if (< i 256)
      (let [f (aget freqs i)]
        (if (zero? f)
          (recur (inc i) fsum)
          (let [f' (max 1 (unsigned-bit-shift-right (* f tr) 31))]
            (aset freqs i f')
            (recur (inc i) (+ fsum f')))))
      fsum)))

(defn- max-frequency-index
  "Returns the index of the first largest entry among the 256 entries of `freqs`."
  ^long [^ints freqs]
  (loop [i 0, m 0, M 0]
    (if (< i 256)
      (let [f (long (aget freqs i))]
        (if (< m f)
          (recur (inc i) f i)
          (recur (inc i) m M)))
      M)))

(defn- normalize-frequencies!
  "Normalizes the 256 raw symbol counts in `freqs` in place so that they sum to
  4095 (one less than 4096), keeping every non-zero count at least 1.

  `total` is the sum of the raw counts and must be positive; a `total` of zero
  makes the initial scale computation divide by zero.

  After scaling, the rounding error is absorbed by the largest entry. When the
  error to remove exceeds half of that entry (many rare symbols clamped up to
  1 next to a flat distribution), subtracting it could leave the entry at zero
  or below, so the table is first shrunk by ~0.98 and the adjustment retried.
  Returns the value stored into the adjusted entry."
  [^ints freqs ^long total]
  (loop [tr (+ (quot (bit-shift-left 4096 31) total)
               (quot (bit-shift-left 1 31) total))]
    (let [fsum' (inc (rescale-frequencies! freqs tr))
          M (max-frequency-index freqs)
          f (aget freqs M)
          ;; positive: entries sum to too much; negative: too little
          excess (- fsum' 4096)]
      (if (> excess (quot f 2))
        ;; 2104533975 / 2^31 ~= 0.98; every pass strictly shrinks entries
        ;; larger than 1, so the retry loop terminates.
        (recur 2104533975)
        (aset freqs M (- f excess))))))

;; Expands to an in-place increment of element `i` of array `arr`:
;; (aset arr i (inc (aget arr i))).
;; Both arguments must be symbols (enforced by the :pre check at expansion
;; time) because each of them appears twice in the expansion.
(defmacro ^:private ainc!
  [arr i]
  {:pre [(symbol? arr) (symbol? i)]}
  `(aset ~arr ~i (inc (aget ~arr ~i))))

(defn- calculate-frequencies0
  "Consumes every remaining byte of `bb`, counts how often each unsigned byte
  value occurs, normalizes the 256 counts with `normalize-frequencies!` and
  returns them as an int array indexed by symbol."
  ^ints [^ByteBuffer bb]
  (let [n (.remaining bb)
        counts (int-array 256)]
    (loop [k 0]
      (when (< k n)
        (let [sym (long (bb/read-ubyte bb))]
          (ainc! counts sym))
        (recur (inc k))))
    (normalize-frequencies! counts n)
    counts))

;; Absolute read: returns the byte stored at index `idx` of `buf` as an
;; unsigned value (0-255) without touching the buffer's position.
(definline ^:private read-ubyte-from [^ByteBuffer buf idx]
  `(bit-and (.get ~buf (long ~idx)) 0xff))

(defn- calculate-frequencies1
  "Builds the 256x256 table of order-1 counts for the remaining bytes of `bb`:
  row = previous byte (0 for the very first byte), column = current byte.
  The bytes found at absolute indices q, 2q and 3q (q = size / 4) are counted
  once more in row 0. Every row with a non-zero total is then normalized in
  place with `normalize-frequencies!`. The relative reads consume `bb`.

  NOTE(review): the three extra lookups use absolute indices, so this looks
  like it expects `bb` to start at position 0 - confirm against callers."
  ^"[[I" [^ByteBuffer bb]
  (let [n (.remaining bb)
        ^"[[I" table (make-array Integer/TYPE 256 256)
        row-totals (int-array 256)]
    ;; Pass over the data, keyed by the preceding symbol.
    (loop [k 0, ctx 0]
      (when (< k n)
        (let [sym (long (bb/read-ubyte bb))
              ^ints row (aget table ctx)]
          (ainc! row sym)
          (ainc! row-totals ctx)
          (recur (inc k) sym))))
    ;; Extra row-0 counts for the bytes at q, 2q and 3q.
    (let [quarter (unsigned-bit-shift-right n 2)
          ^ints row0 (aget table 0)]
      (dotimes [k 3]
        (let [sym (read-ubyte-from bb (* (inc k) quarter))]
          (ainc! row0 sym))))
    (aset row-totals 0 (+ (aget row-totals 0) 3))
    (dotimes [ctx 256]
      (let [row-total (aget row-totals ctx)]
        (when-not (zero? row-total)
          (normalize-frequencies! (aget table ctx) row-total))))
    table))

(defn- next-rle
  "Writes the symbol index for entry `i` of a run-length encoded frequency
  table and returns the updated run counter.

  While `rle` is non-zero the index is implied by an earlier run, so nothing
  is written and `rle` is decremented. Otherwise the index byte `i` is written;
  if the preceding entry of `freqs` is also non-zero, the number of non-zero
  entries that follow `i` is written as a run-length byte and returned."
  ^long [^long rle ^ints freqs ^long i ^ByteBuffer out]
  (if-not (zero? rle)
    (dec rle)
    (do
      (.put out (unchecked-byte i))
      (if (or (<= i 0) (zero? (aget freqs (dec i))))
        rle
        (let [run (loop [k (inc i), cnt 0]
                    (if (and (< k 256)
                             (not (zero? (aget freqs k))))
                      (recur (inc k) (inc cnt))
                      cnt))]
          (.put out (unchecked-byte run))
          run)))))

(defn- encode-itf8
  "Writes `n` to `out` in a cut-down ITF8 form that emits at most two bytes:
  values below 128 take a single byte; anything else is written as
  (128 | n >>> 8) followed by the low eight bits of `n`."
  [^ByteBuffer out ^long n]
  (when-not (< n 128)
    ;; lead byte: continuation flag plus the bits above the low octet
    (.put out (unchecked-byte (bit-or 128 (unsigned-bit-shift-right n 8)))))
  ;; for n < 128 the mask is a no-op, so one expression serves both cases
  (.put out (unchecked-byte (bit-and 0xff n))))

(defn- write-frequencies0
  "Serializes the order-0 frequency table `freqs` to `out`: for every symbol
  with a non-zero frequency, its (possibly run-length implied) index via
  `next-rle` followed by the frequency via `encode-itf8`; a single zero byte
  terminates the table. Returns the number of bytes written."
  ^long [^ByteBuffer out ^ints freqs]
  (let [begin (.position out)]
    (loop [sym 0, run 0]
      (when (< sym 256)
        (let [f (aget freqs sym)]
          (if-not (zero? f)
            (let [run' (next-rle run freqs sym out)]
              (encode-itf8 out f)
              (recur (inc sym) run'))
            (recur (inc sym) run)))))
    (.put out (unchecked-byte 0))
    (- (.position out) begin)))

(defn- write-frequencies1
  "Serializes the order-1 frequency table `freqs` (256 rows of 256 ints) to
  `out`. Rows whose entries sum to zero are skipped. Every other row is
  written as its (possibly run-length implied) row index, then its non-zero
  entries in the same index/frequency layout as the order-0 table, then a
  zero byte. A final zero byte terminates the whole table. Returns the number
  of bytes written."
  ^long [^ByteBuffer out ^"[[I" freqs]
  (let [begin (.position out)
        row-totals (int-array 256)]
    ;; Row sums decide which contexts are present at all.
    (dotimes [ctx 256]
      (let [^ints row (aget freqs ctx)]
        (aset row-totals ctx (areduce row k acc 0 (+ acc (aget row k))))))
    (loop [ctx 0, ctx-run 0]
      (when (< ctx 256)
        (if-not (zero? (aget row-totals ctx))
          (let [ctx-run' (next-rle ctx-run row-totals ctx out)
                ^ints row (aget freqs ctx)]
            (loop [sym 0, sym-run 0]
              (when (< sym 256)
                (let [f (aget row sym)]
                  (if-not (zero? f)
                    (let [sym-run' (next-rle sym-run row sym out)]
                      (encode-itf8 out f)
                      (recur (inc sym) sym-run'))
                    (recur (inc sym) sym-run)))))
            ;; end of this row
            (.put out (unchecked-byte 0))
            (recur (inc ctx) ctx-run'))
          (recur (inc ctx) ctx-run))))
    ;; end of table
    (.put out (unchecked-byte 0))
    (- (.position out) begin)))

;; Operations on a per-symbol encoder state.
;;  - init!   (re)initializes the state from a symbol's cumulative start
;;            position and its frequency.
;;  - update! folds the symbol into the rANS state passed as the last
;;            argument, possibly writing bytes to `bb`, and returns the new
;;            state value.
(defprotocol ISymbolState
  (init! [this start freq])
  (update! [this ^ByteBuffer bb b]))

;; Mutable holder of the per-symbol constants used by the encoder. init!
;; derives them from a symbol's start and frequency; update! then folds the
;; symbol into a state with a multiply and a shift instead of a division.
(deftype SymbolState
  [^:unsynchronized-mutable ^long xmax
   ^:unsynchronized-mutable ^long rcp-freq
   ^:unsynchronized-mutable ^long bias
   ^:unsynchronized-mutable ^long cmpl-freq
   ^:unsynchronized-mutable ^long rcp-shift]
  ISymbolState
  (init! [_ start freq]
    (let [start (long start)
          freq (long freq)]
      ;; states at or above xmax are shrunk by update! before the symbol is folded in
      (set! xmax (* 0x80000 freq))
      ;; complement of the frequency with respect to 0x1000 (4096)
      (set! cmpl-freq (- 0x1000 freq))
      (if (< freq 2)
        ;; freq below 2: all-ones multiplier (masked to 32 bits in update!) and
        ;; a bias of start + 0x1000 - 1
        (do (set! rcp-freq (bit-not 0))
            (set! rcp-shift 0)
            (set! bias (dec (+ start 0x1000))))
        ;; otherwise: shift = smallest value with 2^shift >= freq, and
        ;; rcp-freq = ceil(2^(31 + shift) / freq)
        (let [shift (long
                     (loop [shift 0]
                       (if (< (bit-shift-left 1 shift) freq)
                         (recur (inc shift))
                         shift)))]
          (set! rcp-freq (quot (dec (+ (bit-shift-left 0x80000000 shift) freq)) freq))
          (set! rcp-shift (dec shift))
          (set! bias start)))
      ;; update! multiplies in 64 bits, so the shift always includes an extra 32
      (set! rcp-shift (+ rcp-shift 32))))
  (update! [_ bb r]
    ;; Write the low byte of the state to bb and drop it, at most twice, while
    ;; the state is not below xmax.
    (let [x (long
             (loop [i 2, x (long r)]
               (if (or (zero? i) (< x xmax))
                 x
                 (do (.put ^ByteBuffer bb (unchecked-byte (bit-and 0xff x)))
                     (recur (dec i) (unsigned-bit-shift-right x 8))))))
          ;; q = (x * rcp-freq) >>> rcp-shift, with rcp-freq taken as unsigned 32-bit
          q (unsigned-bit-shift-right (* x (bit-and 0xffffffff rcp-freq)) rcp-shift)]
      ;; the new state value
      (+ x bias (* q cmpl-freq)))))

;; Value every rANS state starts from in the payload encoders (0x800000 = 2^23).
(def ^:private ^:const RANS_BYTE_L 0x800000)

(defn- reverse-buffer!
  "Reverses, in place, the bytes of `bb`'s backing array that lie between the
  buffer's array offset and its limit. Requires an array-backed buffer.
  Returns nil."
  [^ByteBuffer bb]
  (let [backing (.array bb)
        lo (.arrayOffset bb)
        hi (+ lo (.limit bb))]
    ;; swap the k-th byte from the front with the k-th byte from the back
    (dotimes [k (quot (- hi lo) 2)]
      (let [a (+ lo k)
            b (- hi k 1)
            tmp (aget backing a)]
        (aset backing a (aget backing b))
        (aset backing b tmp)))))

;; Order-0 payload encoder. Encodes the remaining bytes of `in` with the
;; per-symbol states in `syms` (indexed by unsigned byte value) into a slice
;; of `out` that starts at `out`'s current position, and returns the number of
;; bytes written to that slice. `in`'s position is moved to its limit;
;; `out`'s own position is not advanced.
;;
;; Four states r0..r3 are interleaved: the byte at index k is folded into
;; state (k mod 4). The input is walked from its last byte to its first, bytes
;; are appended in emission order, and the slice is reversed at the end, so
;; the four final states end up at the front of the payload.
;;
;; NOTE(review): bytes are read at absolute indices 0..raw-size-1, so this
;; looks like it expects `in` to be positioned at 0 - confirm against callers.
(defn- encode-payload0 ^long [^ByteBuffer in ^objects syms ^ByteBuffer out]
  (let [raw-size (.remaining in)
        ;; number of trailing bytes beyond the last multiple of four
        r (bit-and raw-size 3)
        out' (.slice out)
        ;; fold the trailing bytes, highest index first
        r2 (if (= r 3)
             (long (update! (aget syms (read-ubyte-from in (- raw-size r -2))) out' RANS_BYTE_L))
             RANS_BYTE_L)
        r1 (if (>= r 2)
             (long (update! (aget syms (read-ubyte-from in (- raw-size r -1))) out' RANS_BYTE_L))
             RANS_BYTE_L)
        r0 (if (>= r 1)
             (long (update! (aget syms (read-ubyte-from in (- raw-size r))) out' RANS_BYTE_L))
             RANS_BYTE_L)]
    ;; walk the remaining groups of four from the back
    (loop [i (bit-and raw-size (bit-not 3)), r3 RANS_BYTE_L, r2 r2, r1 r1, r0 r0]
      (if (> i 0)
        (let [r3' (long (update! (aget syms (read-ubyte-from in (- i 1))) out' r3))
              r2' (long (update! (aget syms (read-ubyte-from in (- i 2))) out' r2))
              r1' (long (update! (aget syms (read-ubyte-from in (- i 3))) out' r1))
              r0' (long (update! (aget syms (read-ubyte-from in (- i 4))) out' r0))]
          (recur (- i 4) r3' r2' r1' r0'))
        ;; flush the final states, then reverse the written region
        ;; NOTE(review): presumably relies on the slice's own byte order for
        ;; putInt so that the reversal yields the intended layout - confirm.
        (do (.putInt out' r3)
            (.putInt out' r2)
            (.putInt out' r1)
            (.putInt out' r0)
            (.flip ^Buffer out')
            (reverse-buffer! out')
            ;; mark the input as fully consumed
            (.position ^Buffer in (.limit in))
            (.limit out'))))))

;; Two-level array lookup: expands to (aget (aget syms i) j), tagging the
;; inner lookup as an object array so the outer aget gets a type hint.
(defmacro ^:private aget2 [syms i j]
  `(aget ~(with-meta `(aget ~syms ~i) {:tag 'objects}) ~j))

(defn- encode-payload1 ^long [^ByteBuffer in ^objects syms ^ByteBuffer out]
(let [raw-size (.remaining in)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spec says "We do not permit Order-1 encoding of data streams smaller than 4 bytes", but it seems that the current implementation actually handles short input data well.
Do you think we should warn if raw-size is less than 4?

(->> (.getBytes "ab")
     bb/make-lsb-byte-buffer
     (rans/encode 1)
     bb/make-lsb-byte-buffer
     rans/decode
     String.
     (= "ab"))
;; => true

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing that out! I had overlooked that part of the CRAM specification 🙇

It seems that in htslib and htsjdk, inputs smaller than 4 bytes are encoded as Order 0, even if Order 1 encoding is specified. In practice, this implicit fallback to Order 0 for such small inputs doesn’t seem to cause significant issues, so it appears to be a reasonable approach. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for looking into the behavior of other tools!
As you mentioned, I think implicit falling back to Order-0 works just fine! 👍👍

q (unsigned-bit-shift-right raw-size 2)
i0 (- q 2)
i1 (- (* 2 q) 2)
i2 (- (* 3 q) 2)
l0 (if (>= (inc i0) 0) (read-ubyte-from in (inc i0)) 0)
l1 (if (>= (inc i1) 0) (read-ubyte-from in (inc i1)) 0)
l2 (if (>= (inc i2) 0) (read-ubyte-from in (inc i2)) 0)
out' (.slice out)]
(loop [i3 (- raw-size 2)
l3 (read-ubyte-from in (dec raw-size))
r3 RANS_BYTE_L]
(if (and (> i3 (- (* 4 q) 2)) (>= i3 0))
(let [c3 (read-ubyte-from in i3)
r3' (long (update! (aget2 syms c3 l3) out' r3))]
(recur (dec i3) c3 r3'))
(loop [i0 i0, i1 i1, i2 i2, i3 i3,
l0 l0, l1 l1, l2 l2, l3 l3,
r0 RANS_BYTE_L, r1 RANS_BYTE_L, r2 RANS_BYTE_L, r3 r3]
(if (>= i0 0)
(let [c0 (read-ubyte-from in i0)
c1 (read-ubyte-from in i1)
c2 (read-ubyte-from in i2)
c3 (read-ubyte-from in i3)
r3' (long (update! (aget2 syms c3 l3) out' r3))
r2' (long (update! (aget2 syms c2 l2) out' r2))
r1' (long (update! (aget2 syms c1 l1) out' r1))
r0' (long (update! (aget2 syms c0 l0) out' r0))]
(recur (dec i0) (dec i1) (dec i2) (dec i3) c0 c1 c2 c3 r0' r1' r2' r3'))
(let [r3' (long (update! (aget2 syms 0 l3) out' r3))
r2' (long (update! (aget2 syms 0 l2) out' r2))
r1' (long (update! (aget2 syms 0 l1) out' r1))
r0' (long (update! (aget2 syms 0 l0) out' r0))]
(.putInt out' r3')
(.putInt out' r2')
(.putInt out' r1')
(.putInt out' r0')
(.flip ^Buffer out')
(reverse-buffer! out')
(.position ^Buffer in (.limit in))
(.limit out'))))))))

;; Size in bytes of the fixed prefix of an encoded block:
;; 1 (order byte) + 4 (compressed size) + 4 (raw size).
(def ^:private ^:const PREFIX_BYTE_LEN (+ 1 4 4))

(defn- allocate-output-buffer
  "Allocates the output buffer for encoding `raw-size` input bytes. The
  capacity is a generous estimate: 1.05 x the input, plus an upper bound for
  the frequency table, plus the fixed prefix."
  [^long raw-size]
  ;; The size estimation code comes from:
  ;; - https://github.com/samtools/htscodecs/blob/51794289ac47455209c333182b6768f99a613948/htscodecs/rANS_static.c#L77
  ;; - https://github.com/samtools/htscodecs/blob/51794289ac47455209c333182b6768f99a613948/htscodecs/rANS_static.c#L410
  (let [allocated-size (+ ;; truncate explicitly so the allocator is handed an
                          ;; integral capacity rather than a double
                          (long (* 1.05 raw-size))
                          ;; upper bound of frequency table size
                          (* 257 257 3)
                          ;; prefix
                          PREFIX_BYTE_LEN)]
    (bb/allocate-lsb-byte-buffer allocated-size)))

(defn- encode0
  "Order-0 encoding of `in` into `out`: computes the normalized frequencies,
  builds one SymbolState per byte value (initialized only for symbols that
  occur, using the running sum of the preceding frequencies as start), writes
  the frequency table, rewinds `in` and encodes the payload. Returns the total
  number of bytes produced (table + payload)."
  ^long [^Buffer in ^ByteBuffer out]
  (let [freqs (calculate-frequencies0 in)
        syms (object-array 256)]
    (loop [sym 0, cum 0]
      (when (< sym 256)
        (let [f (aget freqs sym)
              state (->SymbolState 0 0 0 0 0)]
          (aset syms sym state)
          (when (pos? f)
            (init! state cum f))
          (recur (inc sym) (+ cum f)))))
    (let [table-size (write-frequencies0 out freqs)]
      ;; the frequency pass consumed the input; start over for the payload
      (.rewind in)
      (+ table-size (encode-payload0 in syms out)))))

(defn- encode1
  "Order-1 encoding of `in` into `out`: computes the normalized 256x256
  frequencies, builds a 256x256 grid of SymbolState (initialized only where
  the frequency is positive, using the running sum within the row as start),
  writes the frequency table, rewinds `in` and encodes the payload. Returns
  the total number of bytes produced (table + payload)."
  ^long [^Buffer in ^ByteBuffer out]
  (let [freqs (calculate-frequencies1 in)
        ^objects syms (make-array SymbolState 256 256)]
    (dotimes [ctx 256]
      (let [^ints row (aget freqs ctx)
            ^objects states (aget syms ctx)]
        (loop [sym 0, cum 0]
          (when (< sym 256)
            (let [f (aget row sym)
                  state (->SymbolState 0 0 0 0 0)]
              (aset states sym state)
              (when (pos? f)
                (init! state cum f))
              (recur (inc sym) (+ cum f)))))))
    (let [table-size (write-frequencies1 out freqs)]
      ;; the frequency pass consumed the input; start over for the payload
      (.rewind in)
      (+ table-size (encode-payload1 in syms out)))))

(defn encode
  "Reads a byte sequence from the given ByteBuffer and encodes it by the rANS4x8 codec.
  Returns the encoded result as a byte array.

  `order` selects the model: 0 for Order-0, 1 for Order-1. Inputs shorter than
  four bytes are always encoded as Order-0. The bytes between `in`'s position
  and its limit are encoded, and on return `in`'s position equals its limit.

  The returned array starts with a 9-byte prefix (order byte, compressed size,
  raw size), followed by the frequency table and the payload.

  NOTE(review): an empty input reaches the frequency normalization with a
  total of zero, which looks like a division by zero - confirm that callers
  never pass an empty buffer."
  ^bytes [^long order ^ByteBuffer in]
  (let [raw-size (.remaining in)
        ;; The helpers rewind their input and index it absolutely from 0, so
        ;; hand them a zero-based view; otherwise a buffer whose position is
        ;; not 0 would be encoded from the wrong bytes.
        in' (.slice in)
        ^ByteBuffer out (doto ^Buffer (allocate-output-buffer raw-size)
                          (.mark)
                          (.position PREFIX_BYTE_LEN))
        ;; According to the specification, Order-1 encoding cannot be applicable
        ;; to the input that is smaller than 4 bytes. So, in that case, the encoder
        ;; automatically falls back to Order-0 encoding.
        order' (if (< raw-size 4) 0 order)
        compressed-size (case order'
                          0 (encode0 in' out)
                          1 (encode1 in' out))]
    ;; the view was consumed, so consume the caller's buffer as well
    (.position ^Buffer in (.limit in))
    ;; go back to the mark (index 0) and fill in the prefix
    (.reset ^Buffer out)
    (.put out (unchecked-byte order'))
    (.putInt out compressed-size)
    (.putInt out raw-size)
    (Arrays/copyOfRange (.array out) 0 (+ PREFIX_BYTE_LEN compressed-size))))
Loading
Loading