diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp
index 68cc1df7da..faae02af23 100644
--- a/gtsam/discrete/DiscreteFactor.cpp
+++ b/gtsam/discrete/DiscreteFactor.cpp
@@ -71,49 +71,12 @@ AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
   return AlgebraicDecisionTree<Key>(dkeys, errors);
 }
 
-/* ************************************************************************* */
-std::vector<double> expNormalize(const std::vector<double>& logProbs) {
-  double maxLogProb = -std::numeric_limits<double>::infinity();
-  for (size_t i = 0; i < logProbs.size(); i++) {
-    double logProb = logProbs[i];
-    if ((logProb != std::numeric_limits<double>::infinity()) &&
-        logProb > maxLogProb) {
-      maxLogProb = logProb;
-    }
-  }
-
-  // After computing the max = "Z" of the log probabilities L_i, we compute
-  // the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
-  double total = 0.0;
-  for (size_t i = 0; i < logProbs.size(); i++) {
-    double probPrime = exp(logProbs[i] - maxLogProb);
-    total += probPrime;
-  }
-  double logTotal = log(total);
-
-  // Now we compute the (normalized) probability (for each i):
-  // p_i = exp(L_i - Z - log S)
-  double checkNormalization = 0.0;
-  std::vector<double> probs;
-  for (size_t i = 0; i < logProbs.size(); i++) {
-    double prob = exp(logProbs[i] - maxLogProb - logTotal);
-    probs.push_back(prob);
-    checkNormalization += prob;
-  }
-
-  // Numerical tolerance for floating point comparisons
-  double tol = 1e-9;
-
-  if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
-    std::string errMsg =
-        std::string("expNormalize failed to normalize probabilities. ") +
-        std::string("Expected normalization constant = 1.0. Got value: ") +
-        std::to_string(checkNormalization) +
-        std::string(
-            "\n This could have resulted from numerical overflow/underflow.");
-    throw std::logic_error(errMsg);
-  }
-  return probs;
+/* ************************************************************************ */
+DiscreteFactor::shared_ptr DiscreteFactor::scale() const {
+  // Max over all the potentials by pretending all keys are frontal:
+  shared_ptr denominator = this->max(this->size());
+  // Normalize the product factor to prevent underflow.
+  return this->operator/(denominator);
 }
 
 }  // namespace gtsam
diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h
index 6cbc00d090..fafb4dbf59 100644
--- a/gtsam/discrete/DiscreteFactor.h
+++ b/gtsam/discrete/DiscreteFactor.h
@@ -158,6 +158,14 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
   /// Create new factor by maximizing over all values with the same separator.
   virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;
 
+  /**
+   * @brief Scale the factor values by the maximum
+   * to prevent underflow/overflow.
+   *
+   * @return DiscreteFactor::shared_ptr
+   */
+  DiscreteFactor::shared_ptr scale() const;
+
   /**
    * Get the number of non-zero values contained in this factor.
    * It could be much smaller than `prod_{key}(cardinality(key))`.
@@ -212,22 +220,4 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
 template <>
 struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 
-/**
- * @brief Normalize a set of log probabilities.
- *
- * Normalizing a set of log probabilities in a numerically stable way is
- * tricky. To avoid overflow/underflow issues, we compute the largest
- * (finite) log probability and subtract it from each log probability before
- * normalizing. This comes from the observation that if:
- *    p_i = exp(L_i) / ( sum_j exp(L_j) ),
- * Then,
- *    p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
- *        = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
- *
- * Setting Z = max_j L_j, we can avoid numerical issues that arise when all
- * of the (unnormalized) log probabilities are either very large or very
- * small.
- */
-std::vector<double> expNormalize(const std::vector<double>& logProbs);
-
 }  // namespace gtsam
diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp
index 7e059c5e5d..f2bae4b9bf 100644
--- a/gtsam/discrete/DiscreteFactorGraph.cpp
+++ b/gtsam/discrete/DiscreteFactorGraph.cpp
@@ -125,11 +125,8 @@ namespace gtsam {
 
     DiscreteFactor::shared_ptr product = this->product();
     gttoc(product);
 
-    // Max over all the potentials by pretending all keys are frontal:
-    auto denominator = product->max(product->size());
-
-    // Normalize the product factor to prevent underflow.
-    product = product->operator/(denominator);
+    product = product->scale();
+
     return product;
   }
@@ -217,13 +214,26 @@ namespace gtsam {
   std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr>  //
   EliminateDiscrete(const DiscreteFactorGraph& factors,
                     const Ordering& frontalKeys) {
-    DiscreteFactor::shared_ptr product = factors.scaledProduct();
+    gttic(product);
+    // `product` is scaled later to prevent underflow.
+    DiscreteFactor::shared_ptr product = factors.product();
+    gttoc(product);
 
     // sum out frontals, this is the factor on the separator
     gttic(sum);
     DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
     gttoc(sum);
 
+    // Normalize/scale to prevent underflow.
+    // We divide both `product` and `sum` by `max(sum)`
+    // since it is faster to compute and when the conditional
+    // is formed by `product/sum`, the scaling term cancels out.
+    gttic(scale);
+    DiscreteFactor::shared_ptr denominator = sum->max(sum->size());
+    product = product->operator/(denominator);
+    sum = sum->operator/(denominator);
+    gttoc(scale);
+
     // Ordering keys for the conditional so that frontalKeys are really in front
     Ordering orderedKeys;
     orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),