diff --git a/include/albatross/CovarianceFunctions b/include/albatross/CovarianceFunctions index 7e6d73be..9de92674 100644 --- a/include/albatross/CovarianceFunctions +++ b/include/albatross/CovarianceFunctions @@ -15,6 +15,8 @@ #include "Indexing" +#include + #include #include #include diff --git a/include/albatross/src/covariance_functions/radial.hpp b/include/albatross/src/covariance_functions/radial.hpp index 3a8ae9bc..27a4de9c 100644 --- a/include/albatross/src/covariance_functions/radial.hpp +++ b/include/albatross/src/covariance_functions/radial.hpp @@ -15,6 +15,7 @@ constexpr double default_length_scale = 100000.; constexpr double default_radial_sigma = 10.; +constexpr double default_nu_matern = 2.5; namespace albatross { @@ -232,5 +233,54 @@ class Matern52 : public CovarianceFunction> { DistanceMetricType distance_metric_; }; +inline double matern_covariance(double distance, double length_scale, double nu, + double sigma = 1.) { + if (length_scale <= 0.) { + return 0; + } + if (distance == 0.) { + return sigma * sigma; + } + assert(nu >= 0); + const double m = 2 * std::sqrt(nu) * distance / length_scale; + return sigma * sigma * std::pow(2, 1 - nu) / std::tgamma(nu) * + std::pow(m, nu) * boost::math::cyl_bessel_k(nu, m); +} + +template +class Matern : public CovarianceFunction> { + public: + // The Matern nu = 5/2 radial function is not positive definite + // when the distance is an angular (or great circle) distance. + static_assert(!std::is_base_of::value, + "Matern covariance with AngularDistance is not PSD."); + + ALBATROSS_DECLARE_PARAMS(matern_length_scale, sigma_matern, nu_matern); + + Matern(double length_scale_ = default_length_scale, + double sigma_matern_ = default_radial_sigma, + double nu_matern_ = default_nu_matern) + : distance_metric_() { + matern_length_scale = {length_scale_, PositivePrior()}; + sigma_matern = {sigma_matern_, NonNegativePrior()}; + nu_matern = {nu_matern_, PositivePrior()}; + }; + + std::string name() const { + return "matern[" + this->distance_metric_.get_name() + "]"; + } + + template ::value, + int>::type = 0> + double _call_impl(const X &x, const X &y) const { + double distance = this->distance_metric_(x, y); + return matern_covariance(distance, matern_length_scale.value, + sigma_matern.value); + } + + DistanceMetricType distance_metric_; +}; } // namespace albatross #endif diff --git a/tests/test_radial.cc b/tests/test_radial.cc index 5010c49c..902f9799 100644 --- a/tests/test_radial.cc +++ b/tests/test_radial.cc @@ -386,4 +386,31 @@ TEST(test_radial, test_matern_32_oracle) { } } +TEST(test_radial, test_matern_peak) { + constexpr std::size_t test_iters = 1000; + std::mt19937 gen{22}; + std::normal_distribution<> d{0., 10.}; + for (std::size_t iter = 0; iter < test_iters; ++iter) { + const double x = d(gen); + const double length = 1e-6 + fabs(d(gen)); + const double nu = 1e-6 + fabs(d(gen)); + const Matern cov(length, 1., nu); + EXPECT_EQ(cov(x, x), 1.0); + } +} + +TEST(test_radial, test_matern_off_peak) { + constexpr std::size_t test_iters = 10000000; + std::mt19937 gen{22}; + std::normal_distribution<> d{0., 10.}; + for (std::size_t iter = 0; iter < test_iters; ++iter) { + const double x = d(gen); + const double delta = 1e-6 + d(gen); + const double length = 1e-6 + fabs(d(gen)); + const double nu = 1e-6 + fabs(d(gen)); + const Matern cov(length, 1., nu); + EXPECT_LT(cov(x, x + delta), 1.0); + } +} + } // namespace albatross