Skip to content

Commit

Permalink
fix float64 condition
Browse files Browse the repository at this point in the history
  • Loading branch information
alantian committed Jun 15, 2022
1 parent 9759bd4 commit e33cdcb
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions evojax/algo/cma_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def __init__(
f"In this case, mean (whose shape is {mean.shape}) must have a dimension of (param_size, )" \
f" (i.e. {(param_size, )}), which is not true."
mean = ensure_jnp(mean)
dtype = mean.dtype
mean_max = ensure_jnp(_MEAN_MAX_X64 if dtype == jnp.float64 else _MEAN_MAX_X32)
mean_max = ensure_jnp(_MEAN_MAX_X64 if jax.config.jax_enable_x64 else _MEAN_MAX_X32)
assert jnp.all(
jnp.abs(mean) < mean_max
), f"Abs of all elements of mean vector must be less than {mean_max}"
Expand Down Expand Up @@ -208,7 +207,7 @@ def __init__(
1.0 / (21.0 * (n_dim ** 2))
),
weights=weights,
sigma_max=ensure_jnp(_SIGMA_MAX_X64 if dtype == jnp.float64 else _SIGMA_MAX_X32),
sigma_max=ensure_jnp(_SIGMA_MAX_X64 if jax.config.jax_enable_x64 else _SIGMA_MAX_X32),
)

# evolution path (state)
Expand Down

0 comments on commit e33cdcb

Please sign in to comment.