diff --git a/tests/distributions/test_gamma.py b/tests/distributions/test_gamma.py index be71146a..f4b06196 100644 --- a/tests/distributions/test_gamma.py +++ b/tests/distributions/test_gamma.py @@ -1034,8 +1034,8 @@ def test_serialization(X): d2 = torch.load(".pytest.torch") os.system("rm .pytest.torch") - assert_array_almost_equal(d2.rates, rates) - assert_array_almost_equal(d2._log_rates, numpy.log(rates)) + assert_array_almost_equal(d2.rates, rates, 4) + assert_array_almost_equal(d2._log_rates, numpy.log(rates), 4) assert_array_almost_equal(d2._w_sum, [3., 3., 3.]) assert_array_almost_equal(d2._xw_sum, [11. , 4.2, 4.4])