diff --git a/precondition/tearfree/sketchy.py b/precondition/tearfree/sketchy.py index 2912781..48a2c79 100644 --- a/precondition/tearfree/sketchy.py +++ b/precondition/tearfree/sketchy.py @@ -449,15 +449,17 @@ def _all_nan(y): log_ranks = jnp.log(jnp.arange(k + 1, d + 1)) fitted_vals = slope * log_ranks + intercept tail = jnp.exp(jax.scipy.special.logsumexp(fitted_vals * 2)) / (d - k) + undeflated = jnp.square(jnp.maximum(top_eigs, 0.0)) else: tail = axis_state.tail * decay + cutoff**2 - # Avoid numerical error from the sqrt computation and from subtracting - # and re-adding cutoff^2 (mathematically, undeflated == deflated^2 + tail). - undeflated = jnp.square(jnp.maximum(top_eigs, 0.0)) + axis_state.tail * decay + # Avoid numerical error from the sqrt computation and from subtracting + # and re-adding cutoff^2 (mathematically, undeflated == deflated^2 + tail). + undeflated = ( + jnp.square(jnp.maximum(top_eigs, 0.0)) + axis_state.tail * decay + ) eigvecs = u[:, :k] mask = deflated > 0 - # Would be nice to statically assert deflated == 0 implies undeflated == 0. alpha = jnp.asarray(-1.0 / (2 * update.ndim), dtype=jnp.float32) eigvecs *= mask