diff --git a/src/primitives/correction.rs b/src/primitives/correction.rs index 7918ac8a..3edc0454 100644 --- a/src/primitives/correction.rs +++ b/src/primitives/correction.rs @@ -6,11 +6,16 @@ //! equation to identify the error values, in a BCH-encoded string. //! +use core::convert::TryInto; +use core::marker::PhantomData; + use crate::primitives::decode::{ CheckedHrpstringError, ChecksumError, InvalidResidueError, SegwitHrpstringError, }; +use crate::primitives::{Field as _, FieldVec, LfsrIter, Polynomial}; #[cfg(feature = "alloc")] use crate::DecodeError; +use crate::{Checksum, Fe32}; /// **One more than** the maximum length (in characters) of a checksum which /// can be error-corrected without an allocator. @@ -57,6 +62,22 @@ pub trait CorrectableError { /// /// This is the function that implementors should implement. fn residue_error(&self) -> Option<&InvalidResidueError>; + + /// Wrapper around [`Self::residue_error`] that outputs a correction context. + /// + /// Will return None if the error is not a correctable one, or if the **alloc** + /// feature is disabled and the checksum is too large. See the documentation + /// for [`NO_ALLOC_MAX_LENGTH`] for more information. + /// + /// This is the function that users should call. + fn correction_context(&self) -> Option> { + #[cfg(not(feature = "alloc"))] + if Ck::CHECKSUM_LENGTH >= NO_ALLOC_MAX_LENGTH { + return None; + } + + self.residue_error().map(|e| Corrector { residue: e.residue(), phantom: PhantomData }) + } } impl CorrectableError for InvalidResidueError { @@ -104,3 +125,186 @@ impl CorrectableError for DecodeError { } } } + +/// An error-correction context. +pub struct Corrector { + residue: Polynomial, + phantom: PhantomData, +} + +impl Corrector { + /// Returns an iterator over the errors in the string. + /// + /// Returns `None` if it can be determined that there are too many errors to be + /// corrected. However, returning an iterator from this function does **not** + /// imply that the intended string can be determined. It only implies that there + /// is a unique closest correct string to the erroneous string, and gives + /// instructions for finding it. + /// + /// If the input string has sufficiently many errors, this unique closest correct + /// string may not actually be the intended string. + pub fn bch_errors(&self) -> Option> { + // 1. Compute all syndromes by evaluating the residue at each power of the generator. + let syndromes: FieldVec<_> = Ck::ROOT_GENERATOR + .powers_range(Ck::ROOT_EXPONENTS) + .map(|rt| self.residue.evaluate(&rt)) + .collect(); + + // 2. Use the Berlekamp-Massey algorithm to find the connection polynomial of the + // LFSR that generates these syndromes. For magical reasons this will be equal + // to the error locator polynomial for the syndrome. + let lfsr = LfsrIter::berlekamp_massey(&syndromes[..]); + let conn = lfsr.coefficient_polynomial(); + + // 3. The connection polynomial is the error locator polynomial. Use this to get + // the errors. + let max_correctable_errors = + (Ck::ROOT_EXPONENTS.end() - Ck::ROOT_EXPONENTS.start() + 1) / 2; + if conn.degree() <= max_correctable_errors { + Some(ErrorIterator { + evaluator: conn.mul_mod_x_d( + &Polynomial::from(syndromes), + Ck::ROOT_EXPONENTS.end() - Ck::ROOT_EXPONENTS.start() + 1, + ), + locator_derivative: conn.formal_derivative(), + inner: conn.find_nonzero_distinct_roots(Ck::ROOT_GENERATOR), + a: Ck::ROOT_GENERATOR, + c: *Ck::ROOT_EXPONENTS.start(), + }) + } else { + None + } + } +} + +/// An iterator over the errors in a string. +/// +/// The errors will be yielded as `(usize, Fe32)` tuples. +/// +/// The first component is a **negative index** into the string. So 0 represents +/// the last element, 1 the second-to-last, and so on. +/// +/// The second component is an element to **add to** the element at the given +/// location in the string. +/// +/// The maximum index is one less than [`Checksum::CODE_LENGTH`], regardless of the +/// actual length of the string. Therefore it is not safe to simply subtract the +/// length of the string from the returned index; you must first check that the +/// index makes sense. If the index exceeds the length of the string or implies that +/// an error occurred in the HRP, the string should simply be rejected as uncorrectable. +/// +/// Out-of-bound error locations will not occur "naturally", in the sense that they +/// will happen with extremely low probability for a string with a valid HRP and a +/// uniform error pattern. (The probability is 32^-n, where n is the size of the +/// range [`Checksum::ROOT_EXPONENTS`], so it is not neglible but is very small for +/// most checksums.) However, it is easy to construct adversarial inputs that will +/// exhibit this behavior, so you must take it into account. +/// +/// Out-of-bound error locations may occur naturally in the case of a string with a +/// corrupted HRP, because for checksumming purposes the HRP is treated as twice as +/// many field elements as characters, plus one. If the correct HRP is known, the +/// caller should fix this before attempting error correction. If it is unknown, +/// the caller cannot assume anything about the intended checksum, and should not +/// attempt error correction. +pub struct ErrorIterator { + evaluator: Polynomial, + locator_derivative: Polynomial, + inner: super::polynomial::RootIter, + a: Ck::CorrectionField, + c: usize, +} + +impl Iterator for ErrorIterator { + type Item = (usize, Fe32); + + fn next(&mut self) -> Option { + // Compute -i, which is the location we will return to the user. + let neg_i = match self.inner.next() { + None => return None, + Some(0) => 0, + Some(x) => Ck::ROOT_GENERATOR.multiplicative_order() - x, + }; + + // Forney's equation, as described in https://en.wikipedia.org/wiki/BCH_code#Forney_algorithm + // + // It is rendered as + // + // a^i evaluator(a^-i) + // e_k = - --------------------------------- + // a^(ci) locator_derivative(a^-i) + // + // where here a is `Ck::ROOT_GENERATOR`, c is the first element of the range + // `Ck::ROOT_EXPONENTS`, and both evalutor and locator_derivative are polynomials + // which are computed when constructing the ErrorIterator. + + let a_i = self.a.powi(neg_i as i64); + let a_neg_i = a_i.clone().multiplicative_inverse(); + + let num = self.evaluator.evaluate(&a_neg_i) * &a_i; + let den = a_i.powi(self.c as i64) * self.locator_derivative.evaluate(&a_neg_i); + let ret = -num / den; + match ret.try_into() { + Ok(ret) => Some((neg_i, ret)), + Err(_) => unreachable!("error guaranteed to lie in base field"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::primitives::decode::SegwitHrpstring; + use crate::Bech32; + + #[test] + fn bech32() { + // Last x should be q + let s = "bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdx"; + match SegwitHrpstring::new(s) { + Ok(_) => panic!("{} successfully, and wrongly, parsed", s), + Err(e) => { + let ctx = e.correction_context::().unwrap(); + let mut iter = ctx.bch_errors().unwrap(); + + assert_eq!(iter.next(), Some((0, Fe32::X))); + assert_eq!(iter.next(), None); + } + } + + // f should be z, 6 chars from the back. + let s = "bc1qar0srrr7xfkvy5l643lydnw9re59gtzfwf5mdq"; + match SegwitHrpstring::new(s) { + Ok(_) => panic!("{} successfully, and wrongly, parsed", s), + Err(e) => { + let ctx = e.correction_context::().unwrap(); + let mut iter = ctx.bch_errors().unwrap(); + + assert_eq!(iter.next(), Some((6, Fe32::T))); + assert_eq!(iter.next(), None); + } + } + + // 20 characters from the end there is a q which should be 3 + let s = "bc1qar0srrr7xfkvy5l64qlydnw9re59gtzzwf5mdq"; + match SegwitHrpstring::new(s) { + Ok(_) => panic!("{} successfully, and wrongly, parsed", s), + Err(e) => { + let ctx = e.correction_context::().unwrap(); + let mut iter = ctx.bch_errors().unwrap(); + + assert_eq!(iter.next(), Some((20, Fe32::_3))); + assert_eq!(iter.next(), None); + } + } + + // Two errors. + let s = "bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mxx"; + match SegwitHrpstring::new(s) { + Ok(_) => panic!("{} successfully, and wrongly, parsed", s), + Err(e) => { + let ctx = e.correction_context::().unwrap(); + assert!(ctx.bch_errors().is_none()); + } + } + } +} diff --git a/src/primitives/decode.rs b/src/primitives/decode.rs index 0cb78b4a..b4ca5e38 100644 --- a/src/primitives/decode.rs +++ b/src/primitives/decode.rs @@ -1021,6 +1021,17 @@ impl InvalidResidueError { pub fn matches_bech32_checksum(&self) -> bool { self.actual == Polynomial::from_residue(Bech32::TARGET_RESIDUE) } + + /// Accessor for the invalid residue, less the target residue. + /// + /// Note that because the error type is not parameterized by a checksum (it + /// holds the target residue but this doesn't help), the caller will need + /// to obtain the checksum from somewhere else in order to make use of this. + /// + /// Not public because [`Polynomial`] is a private type, and because the + /// subtraction will panic if this is called without checking has_data + /// on the FieldVecs. + pub(super) fn residue(&self) -> Polynomial { self.actual.clone() - &self.target } } #[cfg(feature = "std")] diff --git a/src/primitives/lfsr.rs b/src/primitives/lfsr.rs index 8abeb2db..8bb8b977 100644 --- a/src/primitives/lfsr.rs +++ b/src/primitives/lfsr.rs @@ -39,6 +39,9 @@ impl LfsrIter { /// Accessor for the coefficients used to compute the next element. pub fn coefficients(&self) -> &[F] { &self.coeffs.as_inner()[1..] } + /// Accessor for the coefficients used to compute the next element. + pub(super) fn coefficient_polynomial(&self) -> &Polynomial { &self.coeffs } + /// Create a minimal LFSR iterator that generates a set of initial /// contents, using Berlekamp's algorithm. /// diff --git a/src/primitives/polynomial.rs b/src/primitives/polynomial.rs index 04860020..211f5df7 100644 --- a/src/primitives/polynomial.rs +++ b/src/primitives/polynomial.rs @@ -2,7 +2,7 @@ //! Polynomials over Finite Fields -use core::{fmt, iter, ops, slice}; +use core::{cmp, fmt, iter, ops, slice}; use super::checksum::PackedFe32; use super::{ExtensionField, Field, FieldVec}; @@ -26,7 +26,7 @@ impl Eq for Polynomial {} impl Polynomial { pub fn from_residue(residue: R) -> Self { - (0..R::WIDTH).rev().map(|i| Fe32(residue.unpack(i))).collect() + (0..R::WIDTH).map(|i| Fe32(residue.unpack(i))).collect() } } impl Polynomial { @@ -70,7 +70,7 @@ impl Polynomial { /// Panics if [`Self::has_data`] is false. pub fn iter(&self) -> slice::Iter { self.assert_has_data(); - self.inner.iter() + self.inner[..self.degree() + 1].iter() } /// The leading term of the polynomial. @@ -89,6 +89,11 @@ impl Polynomial { /// factor of the polynomial. pub fn zero_is_root(&self) -> bool { self.inner.is_empty() || self.leading_term() == F::ZERO } + /// Computes the formal derivative of the polynomial + pub fn formal_derivative(&self) -> Self { + self.iter().enumerate().map(|(n, fe)| fe.muli(n as i64)).skip(1).collect() + } + /// Helper function to add leading 0 terms until the polynomial has a specified /// length. fn zero_pad_up_to(&mut self, len: usize) { @@ -128,6 +133,38 @@ impl Polynomial { } } + /// Evaluate the polynomial at a given element. + pub fn evaluate>(&self, elem: &E) -> E { + let mut res = E::ZERO; + for fe in self.iter().rev() { + res *= elem; + res += E::from(fe.clone()); + } + res + } + + /// Multiplies two polynomials modulo x^d, for some given `d`. + /// + /// Can be used to simply multiply two polynomials, by passing `usize::MAX` or + /// some other suitably large number as `d`. + pub fn mul_mod_x_d(&self, other: &Self, d: usize) -> Self { + if d == 0 { + return Self { inner: FieldVec::new() }; + } + + let sdeg = self.degree(); + let odeg = other.degree(); + + let convolution_product = |exp: usize| { + let sidx = exp.saturating_sub(sdeg); + let eidx = cmp::min(exp, odeg); + (sidx..=eidx).map(|i| self.inner[exp - i].clone() * &other.inner[i]).sum() + }; + + let max_n = cmp::min(sdeg + odeg + 1, d - 1); + (0..=max_n).map(convolution_product).collect() + } + /// Given a BCH generator polynomial, find an element alpha that maximizes the /// consecutive range i..j such that `alpha^i `through `alpha^j` are all roots /// of the polynomial. @@ -456,4 +493,40 @@ mod tests { panic!("Unexpected generator {}", elem); } } + + #[test] + fn mul_mod() { + let x_minus_1: Polynomial<_> = [Fe32::P, Fe32::P].iter().copied().collect(); + assert_eq!( + x_minus_1.mul_mod_x_d(&x_minus_1, 3), + [Fe32::P, Fe32::Q, Fe32::P].iter().copied().collect(), + ); + assert_eq!(x_minus_1.mul_mod_x_d(&x_minus_1, 2), [Fe32::P].iter().copied().collect(),); + } + + #[test] + #[cfg(feature = "alloc")] // needed since `mul_mod_x_d` produces extra 0 coefficients + fn factor_then_mul() { + let bech32_poly: Polynomial = { + use Fe32 as F; + [F::J, F::A, F::_4, F::_5, F::K, F::A, F::P] + } + .iter() + .copied() + .collect(); + + let bech32_poly_lift = Polynomial { inner: bech32_poly.inner.lift() }; + + let factors = bech32_poly + .find_nonzero_distinct_roots(Fe1024::GENERATOR) + .map(|idx| Fe1024::GENERATOR.powi(idx as i64)) + .map(|root| [root, Fe1024::ONE].iter().copied().collect::>()) + .collect::>(); + + let product = factors.iter().fold( + Polynomial::with_monic_leading_term(&[]), + |acc: Polynomial<_>, factor: &Polynomial<_>| acc.mul_mod_x_d(factor, 100), + ); + assert_eq!(bech32_poly_lift, product); + } }